mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-01 23:35:18 +02:00
Add files via upload
This commit is contained in:
+80
-21
@@ -53,6 +53,37 @@ type ResultStorage interface {
|
||||
DeleteResult(executionID string) error
|
||||
}
|
||||
|
||||
type toolCallInterceptorCtxKey struct{}
|
||||
|
||||
type agentConversationIDKey struct{}
|
||||
|
||||
func withAgentConversationID(ctx context.Context, id string) context.Context {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" || ctx == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, agentConversationIDKey{}, id)
|
||||
}
|
||||
|
||||
func agentConversationIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
v, _ := ctx.Value(agentConversationIDKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolCallInterceptor allows caller to gate or rewrite tool arguments just before execution.
|
||||
// Returning a non-nil error means the tool call is rejected and execution is skipped.
|
||||
type ToolCallInterceptor func(ctx context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error)
|
||||
|
||||
func WithToolCallInterceptor(ctx context.Context, fn ToolCallInterceptor) context.Context {
|
||||
if fn == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, toolCallInterceptorCtxKey{}, fn)
|
||||
}
|
||||
|
||||
// NewAgent 创建新的Agent
|
||||
func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer *mcp.Server, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger, maxIterations int) *Agent {
|
||||
// 如果 maxIterations 为 0 或负数,使用默认值 30
|
||||
@@ -348,7 +379,8 @@ func (a *Agent) EinoSingleAgentSystemInstruction() string {
|
||||
|
||||
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
|
||||
// 设置当前对话ID
|
||||
ctx = withAgentConversationID(ctx, conversationID)
|
||||
// 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准)
|
||||
a.mu.Lock()
|
||||
a.currentConversationID = conversationID
|
||||
a.mu.Unlock()
|
||||
@@ -653,22 +685,49 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"iteration": i + 1,
|
||||
})
|
||||
|
||||
execArgs := toolCall.Function.Arguments
|
||||
if interceptor, ok := ctx.Value(toolCallInterceptorCtxKey{}).(ToolCallInterceptor); ok && interceptor != nil {
|
||||
newArgs, interceptErr := interceptor(ctx, toolCall.Function.Name, execArgs, toolCall.ID)
|
||||
if interceptErr != nil {
|
||||
errorMsg := fmt.Sprintf("工具调用被人工拒绝: %v", interceptErr)
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: errorMsg,
|
||||
})
|
||||
sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{
|
||||
"toolName": toolCall.Function.Name,
|
||||
"success": false,
|
||||
"isError": true,
|
||||
"error": errorMsg,
|
||||
"toolCallId": toolCall.ID,
|
||||
"index": idx + 1,
|
||||
"total": len(choice.Message.ToolCalls),
|
||||
"iteration": i + 1,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if newArgs != nil {
|
||||
execArgs = newArgs
|
||||
}
|
||||
}
|
||||
|
||||
// 执行工具
|
||||
toolCtx := context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(chunk string) {
|
||||
if strings.TrimSpace(chunk) == "" {
|
||||
return
|
||||
}
|
||||
sendProgress("tool_result_delta", chunk, map[string]interface{}{
|
||||
"toolName": toolCall.Function.Name,
|
||||
"toolCallId": toolCall.ID,
|
||||
"index": idx + 1,
|
||||
"total": len(choice.Message.ToolCalls),
|
||||
"iteration": i + 1,
|
||||
"toolName": toolCall.Function.Name,
|
||||
"toolCallId": toolCall.ID,
|
||||
"index": idx + 1,
|
||||
"total": len(choice.Message.ToolCalls),
|
||||
"iteration": i + 1,
|
||||
// success 在最终 tool_result 事件里会以 success/isError 标记为准
|
||||
})
|
||||
}))
|
||||
|
||||
execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, execArgs)
|
||||
if err != nil {
|
||||
// 构建详细的错误信息,帮助AI理解问题并做出决策
|
||||
errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err)
|
||||
@@ -746,7 +805,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
@@ -793,7 +852,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
@@ -840,7 +899,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "max_iter_summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
@@ -913,17 +972,13 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
defer cancel()
|
||||
|
||||
externalTools, err := a.externalMCPMgr.GetAllTools(ctx)
|
||||
extMap := make(map[string]string)
|
||||
if err != nil {
|
||||
a.logger.Warn("获取外部MCP工具失败", zap.Error(err))
|
||||
} else {
|
||||
// 获取外部MCP配置,用于检查工具启用状态
|
||||
externalMCPConfigs := a.externalMCPMgr.GetConfigs()
|
||||
|
||||
// 清空并重建工具名称映射
|
||||
a.mu.Lock()
|
||||
a.toolNameMapping = make(map[string]string)
|
||||
a.mu.Unlock()
|
||||
|
||||
// 将外部MCP工具添加到工具列表(只添加启用的工具)
|
||||
for _, externalTool := range externalTools {
|
||||
// 外部工具使用 "mcpName::toolName" 作为toolKey
|
||||
@@ -983,9 +1038,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
openAIName := strings.ReplaceAll(externalTool.Name, "::", "__")
|
||||
|
||||
// 保存名称映射关系(OpenAI格式 -> 原始格式)
|
||||
a.mu.Lock()
|
||||
a.toolNameMapping[openAIName] = externalTool.Name
|
||||
a.mu.Unlock()
|
||||
extMap[openAIName] = externalTool.Name
|
||||
|
||||
tools = append(tools, Tool{
|
||||
Type: "function",
|
||||
@@ -997,6 +1050,9 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
})
|
||||
}
|
||||
}
|
||||
a.mu.Lock()
|
||||
a.toolNameMapping = extMap
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
a.logger.Debug("获取可用工具列表",
|
||||
@@ -1390,9 +1446,12 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
||||
|
||||
// 如果是record_vulnerability工具,自动添加conversation_id
|
||||
if toolName == builtin.ToolRecordVulnerability {
|
||||
a.mu.RLock()
|
||||
conversationID := a.currentConversationID
|
||||
a.mu.RUnlock()
|
||||
conversationID := agentConversationIDFromContext(ctx)
|
||||
if conversationID == "" {
|
||||
a.mu.RLock()
|
||||
conversationID = a.currentConversationID
|
||||
a.mu.RUnlock()
|
||||
}
|
||||
|
||||
if conversationID != "" {
|
||||
args["conversation_id"] = conversationID
|
||||
|
||||
@@ -326,6 +326,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
|
||||
@@ -654,9 +655,15 @@ func setupRoutes(
|
||||
// Eino ADK 单代理(ChatModelAgent + Runner;不依赖 multi_agent.enabled)
|
||||
protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop)
|
||||
protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream)
|
||||
protected.GET("/hitl/pending", agentHandler.ListHITLPending)
|
||||
protected.POST("/hitl/decision", agentHandler.DecideHITLInterrupt)
|
||||
protected.GET("/hitl/config/:conversationId", agentHandler.GetHITLConversationConfig)
|
||||
protected.PUT("/hitl/config", agentHandler.UpsertHITLConversationConfig)
|
||||
protected.POST("/hitl/tool-whitelist", agentHandler.MergeHITLGlobalToolWhitelist)
|
||||
// Agent Loop 取消与任务列表
|
||||
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
|
||||
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
|
||||
protected.GET("/agent-loop/task-events", agentHandler.SubscribeAgentTaskEvents)
|
||||
protected.GET("/agent-loop/tasks/completed", agentHandler.ListCompletedTasks)
|
||||
|
||||
// Eino DeepAgent 多代理(与单 Agent 并存,需 config.multi_agent.enabled)
|
||||
|
||||
+194
-16
@@ -115,7 +115,9 @@ type AgentHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
tasks *AgentTaskManager
|
||||
taskEventBus *TaskEventBus // 镜像 SSE 事件,供刷新后订阅同一运行中任务
|
||||
batchTaskManager *BatchTaskManager
|
||||
hitlManager *HITLManager
|
||||
config *config.Config // 配置引用,用于获取角色信息
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
|
||||
@@ -124,6 +126,13 @@ type AgentHandler struct {
|
||||
batchCronParser cron.Parser
|
||||
batchRunnerMu sync.Mutex
|
||||
batchRunning map[string]struct{}
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
type HitlToolWhitelistSaver interface {
|
||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
@@ -136,16 +145,24 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
logger.Warn("从数据库加载批量任务队列失败", zap.Error(err))
|
||||
}
|
||||
|
||||
bus := NewTaskEventBus()
|
||||
tm := NewAgentTaskManager()
|
||||
tm.SetTaskEventBus(bus)
|
||||
handler := &AgentHandler{
|
||||
agent: agent,
|
||||
db: db,
|
||||
logger: logger,
|
||||
tasks: NewAgentTaskManager(),
|
||||
tasks: tm,
|
||||
taskEventBus: bus,
|
||||
batchTaskManager: batchTaskManager,
|
||||
config: cfg,
|
||||
hitlManager: NewHITLManager(db, logger),
|
||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||
batchRunning: make(map[string]struct{}),
|
||||
}
|
||||
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
||||
logger.Warn("初始化 HITL 表失败", zap.Error(err))
|
||||
}
|
||||
go handler.batchQueueSchedulerLoop()
|
||||
return handler
|
||||
}
|
||||
@@ -162,6 +179,11 @@ func (h *AgentHandler) SetAgentsMarkdownDir(absDir string) {
|
||||
h.agentsMarkdownDir = strings.TrimSpace(absDir)
|
||||
}
|
||||
|
||||
// SetHitlToolWhitelistSaver 设置 HITL 白名单落盘(与 ConfigHandler 配合,避免循环引用用接口)
|
||||
func (h *AgentHandler) SetHitlToolWhitelistSaver(s HitlToolWhitelistSaver) {
|
||||
h.hitlWhitelistSaver = s
|
||||
}
|
||||
|
||||
// ChatAttachment 聊天附件(用户上传的文件)
|
||||
type ChatAttachment struct {
|
||||
FileName string `json:"fileName"` // 展示用文件名
|
||||
@@ -177,10 +199,18 @@ type ChatRequest struct {
|
||||
Role string `json:"role,omitempty"` // 角色名称
|
||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
||||
Hitl *HITLRequest `json:"hitl,omitempty"`
|
||||
// Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
}
|
||||
|
||||
type HITLRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
SensitiveTools []string `json:"sensitiveTools,omitempty"`
|
||||
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
maxAttachments = 10
|
||||
chatUploadsDirName = "chat_uploads" // 对话附件保存的根目录(相对当前工作目录)
|
||||
@@ -462,6 +492,11 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
h.activateHITLForConversation(conversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(conversationID)
|
||||
}
|
||||
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
if err != nil {
|
||||
@@ -566,8 +601,13 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(baseCtx, cancelWithCause, conversationID, "", nil)
|
||||
baseCtx = h.injectReactHITLInterceptor(baseCtx, cancelWithCause, conversationID, "", nil)
|
||||
|
||||
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(baseCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
@@ -661,7 +701,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil)
|
||||
progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil)
|
||||
|
||||
useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent
|
||||
if useRobotMulti {
|
||||
@@ -755,9 +795,41 @@ type StreamEvent struct {
|
||||
|
||||
// createProgressCallback 创建进度回调函数,用于保存processDetails
|
||||
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
|
||||
func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
||||
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
||||
// 用于保存tool_call事件中的参数,以便在tool_result时使用
|
||||
toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments
|
||||
skillCallCache := make(map[string]string) // toolCallId -> skillName
|
||||
skillToolName := "skill"
|
||||
if h.config != nil {
|
||||
if customName := strings.TrimSpace(h.config.MultiAgent.EinoSkills.SkillToolName); customName != "" {
|
||||
skillToolName = customName
|
||||
}
|
||||
}
|
||||
|
||||
extractSkillName := func(args map[string]interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, key := range []string{"skill_name", "skillName", "name", "skill", "id", "skill_id", "skillId"} {
|
||||
if v, ok := args[key]; ok {
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
if s := strings.TrimSpace(vv); s != "" {
|
||||
return s
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for _, nestedKey := range []string{"name", "id", "skill_name", "skillId"} {
|
||||
if nestedV, nestedOK := vv[nestedKey].(string); nestedOK {
|
||||
if s := strings.TrimSpace(nestedV); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// thinking_stream_*:不逐条落库,按 streamId 聚合,在后续关键事件前补一条可持久化的 thinking
|
||||
type thinkingBuf struct {
|
||||
@@ -840,6 +912,16 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(toolName), skillToolName) {
|
||||
toolCallID, _ := dataMap["toolCallId"].(string)
|
||||
if toolCallID != "" {
|
||||
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
|
||||
if skillName := extractSkillName(argumentsObj); skillName != "" {
|
||||
skillCallCache[toolCallID] = skillName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -953,6 +1035,45 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
}
|
||||
}
|
||||
|
||||
// 记录 skills 调用统计(tool_call + tool_result 关联)
|
||||
if eventType == "tool_result" && h.db != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
if strings.EqualFold(strings.TrimSpace(toolName), skillToolName) {
|
||||
toolCallID, _ := dataMap["toolCallId"].(string)
|
||||
skillName := ""
|
||||
if toolCallID != "" {
|
||||
skillName = strings.TrimSpace(skillCallCache[toolCallID])
|
||||
delete(skillCallCache, toolCallID)
|
||||
}
|
||||
if skillName == "" {
|
||||
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
|
||||
skillName = strings.TrimSpace(extractSkillName(argumentsObj))
|
||||
}
|
||||
}
|
||||
if skillName != "" {
|
||||
success, ok := dataMap["success"].(bool)
|
||||
if !ok {
|
||||
if isError, okErr := dataMap["isError"].(bool); okErr {
|
||||
success = !isError
|
||||
}
|
||||
}
|
||||
successCalls := 0
|
||||
failedCalls := 0
|
||||
if success {
|
||||
successCalls = 1
|
||||
} else {
|
||||
failedCalls = 1
|
||||
}
|
||||
now := time.Now()
|
||||
if err := h.db.UpdateSkillStats(skillName, 1, successCalls, failedCalls, &now); err != nil {
|
||||
h.logger.Warn("更新Skills调用统计失败", zap.Error(err), zap.String("skill", skillName))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 子代理回复流式增量不落库;结束时合并为一条 eino_agent_reply
|
||||
if assistantMessageID != "" && eventType == "eino_agent_reply_stream_end" {
|
||||
flushResponsePlan()
|
||||
@@ -1108,6 +1229,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
clientDisconnected := false
|
||||
// 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。
|
||||
var sseWriteMu sync.Mutex
|
||||
var ssePublishConversationID string
|
||||
// 用于快速确认模型是否真的产生了流式 delta
|
||||
var responseDeltaCount int
|
||||
var responseStartLogged bool
|
||||
@@ -1155,7 +1277,24 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果客户端已断开,不再发送事件
|
||||
event := StreamEvent{
|
||||
Type: eventType,
|
||||
Message: message,
|
||||
Data: data,
|
||||
}
|
||||
eventJSON, errJSON := json.Marshal(event)
|
||||
if errJSON != nil {
|
||||
eventJSON = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
sseLine := make([]byte, 0, len(eventJSON)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, eventJSON...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
if ssePublishConversationID != "" && h.taskEventBus != nil {
|
||||
h.taskEventBus.Publish(ssePublishConversationID, sseLine)
|
||||
}
|
||||
|
||||
// 如果客户端已断开,不再写入 HTTP(镜像订阅仍可收到事件)
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
@@ -1168,15 +1307,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
default:
|
||||
}
|
||||
|
||||
event := StreamEvent{
|
||||
Type: eventType,
|
||||
Message: message,
|
||||
Data: data,
|
||||
}
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
|
||||
sseWriteMu.Lock()
|
||||
_, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
||||
_, err := c.Writer.Write(sseLine)
|
||||
if err != nil {
|
||||
sseWriteMu.Unlock()
|
||||
clientDisconnected = true
|
||||
@@ -1220,6 +1352,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
ssePublishConversationID = conversationID
|
||||
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
@@ -1350,14 +1483,14 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建进度回调函数,复用统一逻辑
|
||||
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)
|
||||
|
||||
// 创建一个独立的上下文用于任务执行,不随HTTP请求取消
|
||||
// 这样即使客户端断开连接(如刷新页面),任务也能继续执行
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
@@ -1606,6 +1739,51 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// SubscribeAgentTaskEvents GET SSE:订阅指定会话当前运行中任务的事件镜像(帧格式与 POST .../stream 一致),用于刷新页面或断线后接续 UI。
|
||||
func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
if conversationID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||||
return
|
||||
}
|
||||
if h.tasks.GetTask(conversationID) == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no active task for this conversation"})
|
||||
return
|
||||
}
|
||||
if h.taskEventBus == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "task event bus unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
sub, ch := h.taskEventBus.Subscribe(conversationID)
|
||||
defer h.taskEventBus.Unsubscribe(conversationID, sub)
|
||||
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
ctx := c.Request.Context()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case chunk, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := c.Writer.Write(chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ListAgentTasks 列出所有运行中的任务
|
||||
func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -2266,7 +2444,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil)
|
||||
progressCallback := h.createProgressCallback(context.Background(), nil, conversationID, assistantMessageID, nil)
|
||||
|
||||
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||
|
||||
@@ -187,6 +187,7 @@ type GetConfigResponse struct {
|
||||
MCP config.MCPConfig `json:"mcp"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Agent config.AgentConfig `json:"agent"`
|
||||
Hitl config.HitlConfig `json:"hitl,omitempty"`
|
||||
Knowledge config.KnowledgeConfig `json:"knowledge"`
|
||||
Robots config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"`
|
||||
@@ -282,6 +283,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
MCP: h.config.MCP,
|
||||
Tools: tools,
|
||||
Agent: h.config.Agent,
|
||||
Hitl: h.config.Hitl,
|
||||
Knowledge: h.config.Knowledge,
|
||||
Robots: h.config.Robots,
|
||||
MultiAgent: multiPub,
|
||||
@@ -1132,6 +1134,7 @@ func (h *ConfigHandler) saveConfig() error {
|
||||
updateFOFAConfig(root, h.config.FOFA)
|
||||
updateKnowledgeConfig(root, h.config.Knowledge)
|
||||
updateRobotsConfig(root, h.config.Robots)
|
||||
updateHitlConfig(root, h.config.Hitl)
|
||||
updateMultiAgentConfig(root, h.config.MultiAgent)
|
||||
// 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用)
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP)
|
||||
@@ -1308,6 +1311,47 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
|
||||
setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs)
|
||||
}
|
||||
|
||||
func mergeHitlToolWhitelistSlice(existing, add []string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0, len(existing)+len(add))
|
||||
for _, list := range [][]string{existing, add} {
|
||||
for _, t := range list {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
out = append(out, strings.TrimSpace(t))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。
|
||||
func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
merged := mergeHitlToolWhitelistSlice(h.config.Hitl.ToolWhitelist, add)
|
||||
h.config.Hitl.ToolWhitelist = merged
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Info("HITL 全局工具白名单已合并写入配置文件",
|
||||
zap.Int("count", len(merged)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) {
|
||||
root := doc.Content[0]
|
||||
hitlNode := ensureMap(root, "hitl")
|
||||
// flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数
|
||||
setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist)
|
||||
}
|
||||
|
||||
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
root := doc.Content[0]
|
||||
robotsNode := ensureMap(root, "robots")
|
||||
@@ -1418,6 +1462,21 @@ func setStringSliceInMap(mapNode *yaml.Node, key string, values []string) {
|
||||
}
|
||||
}
|
||||
|
||||
func setFlowStringSliceInMap(mapNode *yaml.Node, key string, values []string) {
|
||||
_, valueNode := ensureKeyValue(mapNode, key)
|
||||
valueNode.Kind = yaml.SequenceNode
|
||||
valueNode.Tag = "!!seq"
|
||||
valueNode.Style = yaml.FlowStyle
|
||||
valueNode.Content = nil
|
||||
for _, v := range values {
|
||||
valueNode.Content = append(valueNode.Content, &yaml.Node{
|
||||
Kind: yaml.ScalarNode,
|
||||
Tag: "!!str",
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setIntInMap(mapNode *yaml.Node, key string, value int) {
|
||||
_, valueNode := ensureKeyValue(mapNode, key)
|
||||
valueNode.Kind = yaml.ScalarNode
|
||||
|
||||
@@ -41,11 +41,24 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
var baseCtx context.Context
|
||||
clientDisconnected := false
|
||||
var sseWriteMu sync.Mutex
|
||||
var ssePublishConversationID string
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if clientDisconnected {
|
||||
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) {
|
||||
return
|
||||
}
|
||||
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) {
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, errMarshal := json.Marshal(ev)
|
||||
if errMarshal != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
sseLine := make([]byte, 0, len(b)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, b...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
if ssePublishConversationID != "" && h.taskEventBus != nil {
|
||||
h.taskEventBus.Publish(ssePublishConversationID, sseLine)
|
||||
}
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
select {
|
||||
@@ -54,10 +67,8 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
return
|
||||
default:
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, _ := json.Marshal(ev)
|
||||
sseWriteMu.Lock()
|
||||
_, err := fmt.Fprintf(c.Writer, "data: %s\n\n", b)
|
||||
_, err := c.Writer.Write(sseLine)
|
||||
if err != nil {
|
||||
sseWriteMu.Unlock()
|
||||
clientDisconnected = true
|
||||
@@ -81,6 +92,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
sendEvent("done", "", nil)
|
||||
return
|
||||
}
|
||||
ssePublishConversationID = prep.ConversationID
|
||||
if prep.CreatedNew {
|
||||
sendEvent("conversation", "会话已创建", map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
@@ -89,6 +101,10 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
h.activateHITLForConversation(conversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(conversationID)
|
||||
}
|
||||
|
||||
if prep.UserMessageID != "" {
|
||||
sendEvent("message_saved", "", map[string]interface{}{
|
||||
@@ -97,13 +113,15 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
@@ -136,6 +154,8 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
defer close(stopKeepalive)
|
||||
|
||||
if h.config == nil {
|
||||
taskStatus = "failed"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
sendEvent("error", "服务器配置未加载", nil)
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return
|
||||
@@ -166,7 +186,24 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
|
||||
taskStatus = "timeout"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return
|
||||
@@ -232,12 +269,22 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
h.activateHITLForConversation(prep.ConversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
|
||||
var progressBuf strings.Builder
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
progressCallbackRaw := func(eventType, message string, data interface{}) {
|
||||
progressBuf.WriteString(eventType)
|
||||
progressBuf.WriteByte('\n')
|
||||
}
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, progressCallbackRaw)
|
||||
baseCtx = multiagent.WithHITLToolInterceptor(baseCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
|
||||
})
|
||||
|
||||
if h.config == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器配置未加载"})
|
||||
@@ -245,7 +292,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
result, runErr := multiagent.RunEinoSingleChatModelAgent(
|
||||
c.Request.Context(),
|
||||
baseCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
@@ -279,10 +326,10 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"response": result.Response,
|
||||
"conversationId": prep.ConversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"response": result.Response,
|
||||
"conversationId": prep.ConversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"assistantMessageId": prep.AssistantMessageID,
|
||||
"agentMode": "eino_single",
|
||||
"agentMode": "eino_single",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,748 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type hitlRuntimeConfig struct {
|
||||
Enabled bool
|
||||
Mode string
|
||||
SensitiveTools map[string]struct{}
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type hitlDecision struct {
|
||||
Decision string
|
||||
Comment string
|
||||
EditedArguments map[string]interface{}
|
||||
}
|
||||
|
||||
type pendingInterrupt struct {
|
||||
ConversationID string
|
||||
InterruptID string
|
||||
Mode string
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
decideCh chan hitlDecision
|
||||
}
|
||||
|
||||
type HITLManager struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
runtime map[string]hitlRuntimeConfig
|
||||
pending map[string]*pendingInterrupt
|
||||
}
|
||||
|
||||
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
|
||||
return &HITLManager{
|
||||
db: db,
|
||||
logger: logger,
|
||||
runtime: make(map[string]hitlRuntimeConfig),
|
||||
pending: make(map[string]*pendingInterrupt),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) EnsureSchema() error {
|
||||
if _, err := m.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS hitl_interrupts (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
message_id TEXT,
|
||||
mode TEXT NOT NULL,
|
||||
tool_name TEXT NOT NULL,
|
||||
tool_call_id TEXT,
|
||||
payload TEXT,
|
||||
status TEXT NOT NULL,
|
||||
decision TEXT,
|
||||
decision_comment TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
decided_at DATETIME
|
||||
);`); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := m.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
|
||||
conversation_id TEXT PRIMARY KEY,
|
||||
enabled INTEGER NOT NULL DEFAULT 0,
|
||||
mode TEXT NOT NULL DEFAULT 'off',
|
||||
sensitive_tools TEXT NOT NULL DEFAULT '[]',
|
||||
timeout_seconds INTEGER NOT NULL DEFAULT 300,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`)
|
||||
return err
|
||||
}
|
||||
|
||||
func normalizeHitlMode(mode string) string {
|
||||
v := strings.ToLower(strings.TrimSpace(mode))
|
||||
if v == "" {
|
||||
return "approval"
|
||||
}
|
||||
switch v {
|
||||
case "off":
|
||||
return "off"
|
||||
case "feedback", "followup":
|
||||
return "approval"
|
||||
case "approval", "review_edit":
|
||||
return v
|
||||
default:
|
||||
return "approval"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) ActivateConversation(conversationID string, req *HITLRequest) {
|
||||
if req == nil || !req.Enabled {
|
||||
m.DeactivateConversation(conversationID)
|
||||
return
|
||||
}
|
||||
tools := make(map[string]struct{})
|
||||
for _, t := range req.SensitiveTools {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n != "" {
|
||||
tools[n] = struct{}{}
|
||||
}
|
||||
}
|
||||
timeout := 5 * time.Minute
|
||||
if req.TimeoutSeconds > 0 {
|
||||
timeout = time.Duration(req.TimeoutSeconds) * time.Second
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.runtime[conversationID] = hitlRuntimeConfig{
|
||||
Enabled: true,
|
||||
Mode: normalizeHitlMode(req.Mode),
|
||||
SensitiveTools: tools,
|
||||
Timeout: timeout,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *HITLManager) DeactivateConversation(conversationID string) {
|
||||
m.mu.Lock()
|
||||
delete(m.runtime, conversationID)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。
|
||||
func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
|
||||
if h == nil || h.config == nil {
|
||||
return nil
|
||||
}
|
||||
raw := h.config.Hitl.ToolWhitelist
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, t := range raw {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
out = append(out, strings.TrimSpace(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。
|
||||
func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest {
|
||||
gw := h.hitlConfigGlobalToolWhitelist()
|
||||
if len(gw) == 0 {
|
||||
return req
|
||||
}
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
union := make([]string, 0, len(gw)+len(req.SensitiveTools))
|
||||
for _, t := range gw {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
union = append(union, strings.TrimSpace(t))
|
||||
}
|
||||
for _, t := range req.SensitiveTools {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
union = append(union, strings.TrimSpace(t))
|
||||
}
|
||||
out := *req
|
||||
out.SensitiveTools = union
|
||||
return &out
|
||||
}
|
||||
|
||||
func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRuntimeConfig, bool) {
|
||||
m.mu.RLock()
|
||||
cfg, ok := m.runtime[conversationID]
|
||||
m.mu.RUnlock()
|
||||
if !ok || !cfg.Enabled {
|
||||
return hitlRuntimeConfig{}, false
|
||||
}
|
||||
// 语义:SensitiveTools 现在作为“白名单(免审批工具)”
|
||||
// 空白名单 => 全部工具都需要审批
|
||||
if len(cfg.SensitiveTools) == 0 {
|
||||
return cfg, true
|
||||
}
|
||||
_, inWhitelist := cfg.SensitiveTools[strings.ToLower(strings.TrimSpace(toolName))]
|
||||
return cfg, !inWhitelist
|
||||
}
|
||||
|
||||
func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) {
|
||||
now := time.Now()
|
||||
id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
if _, err := m.db.Exec(`INSERT INTO hitl_interrupts
|
||||
(id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`,
|
||||
id, conversationID, assistantMessageID, mode, toolName, toolCallID, payload, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 刷新页面后侧栏依赖 DB 配置;若仅内存 Activate 未落库,会导致「有待审批却显示关闭」
|
||||
_ = m.ensureConversationHITLModePersisted(conversationID, mode)
|
||||
p := &pendingInterrupt{
|
||||
ConversationID: conversationID,
|
||||
InterruptID: id,
|
||||
Mode: normalizeHitlMode(mode),
|
||||
ToolName: toolName,
|
||||
ToolCallID: toolCallID,
|
||||
decideCh: make(chan hitlDecision, 1),
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.pending[id] = p
|
||||
m.mu.Unlock()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// ensureConversationHITLModePersisted 在产生待审批时把 mode 写入 hitl_conversation_configs,避免刷新后 GET 配置仍为关闭。
|
||||
func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interruptMode string) error {
|
||||
if strings.TrimSpace(conversationID) == "" {
|
||||
return nil
|
||||
}
|
||||
nm := normalizeHitlMode(interruptMode)
|
||||
if nm == "off" {
|
||||
return nil
|
||||
}
|
||||
cfg, err := m.LoadConversationConfig(conversationID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg.Enabled && normalizeHitlMode(cfg.Mode) == nm {
|
||||
return nil
|
||||
}
|
||||
cfg.Enabled = true
|
||||
cfg.Mode = nm
|
||||
if cfg.TimeoutSeconds <= 0 {
|
||||
cfg.TimeoutSeconds = 300
|
||||
}
|
||||
return m.SaveConversationConfig(conversationID, cfg)
|
||||
}
|
||||
|
||||
// PendingHITLInterruptMode 返回该会话最新一条 pending 中断的协同模式(用于 GET 配置时与库内「关闭」状态对齐)。
|
||||
func (m *HITLManager) PendingHITLInterruptMode(conversationID string) (string, bool) {
|
||||
if strings.TrimSpace(conversationID) == "" {
|
||||
return "", false
|
||||
}
|
||||
var mode string
|
||||
err := m.db.QueryRow(`SELECT mode FROM hitl_interrupts WHERE conversation_id = ? AND status = 'pending' ORDER BY created_at DESC LIMIT 1`, conversationID).
|
||||
Scan(&mode)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", false
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
mode = strings.TrimSpace(mode)
|
||||
if mode == "" {
|
||||
return "", false
|
||||
}
|
||||
return mode, true
|
||||
}
|
||||
|
||||
func hitlStoredConfigEffective(cfg *HITLRequest) bool {
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
if cfg.Enabled {
|
||||
return true
|
||||
}
|
||||
return normalizeHitlMode(cfg.Mode) != "off"
|
||||
}
|
||||
|
||||
func (m *HITLManager) ResolveInterrupt(interruptID, decision, comment string, editedArguments map[string]interface{}) error {
|
||||
decision = strings.ToLower(strings.TrimSpace(decision))
|
||||
if decision != "approve" && decision != "reject" {
|
||||
return errors.New("decision must be approve/reject")
|
||||
}
|
||||
m.mu.RLock()
|
||||
p, ok := m.pending[interruptID]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return errors.New("interrupt not found or already resolved")
|
||||
}
|
||||
d := hitlDecision{
|
||||
Decision: decision,
|
||||
Comment: strings.TrimSpace(comment),
|
||||
EditedArguments: editedArguments,
|
||||
}
|
||||
select {
|
||||
case p.decideCh <- d:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("interrupt already resolved or decision channel busy")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLRequest) error {
|
||||
if strings.TrimSpace(conversationID) == "" {
|
||||
return errors.New("conversationId is required")
|
||||
}
|
||||
if req == nil {
|
||||
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 300}
|
||||
}
|
||||
mode := normalizeHitlMode(req.Mode)
|
||||
if !req.Enabled {
|
||||
mode = "off"
|
||||
}
|
||||
tools, _ := json.Marshal(req.SensitiveTools)
|
||||
timeout := req.TimeoutSeconds
|
||||
if timeout <= 0 {
|
||||
timeout = 300
|
||||
}
|
||||
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
|
||||
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(conversation_id) DO UPDATE SET
|
||||
enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
||||
conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now())
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
|
||||
var enabledInt int
|
||||
var mode, toolsJSON string
|
||||
var timeout int
|
||||
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 300}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools := make([]string, 0)
|
||||
_ = json.Unmarshal([]byte(toolsJSON), &tools)
|
||||
return &HITLRequest{
|
||||
Enabled: enabledInt == 1,
|
||||
Mode: mode,
|
||||
SensitiveTools: tools,
|
||||
TimeoutSeconds: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, timeout time.Duration) (hitlDecision, error) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
delete(m.pending, p.InterruptID)
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
select {
|
||||
case d := <-p.decideCh:
|
||||
// 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments
|
||||
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
|
||||
d.EditedArguments = nil
|
||||
}
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
|
||||
d.Decision, d.Comment, time.Now(), p.InterruptID)
|
||||
return d, nil
|
||||
case <-time.After(timeout):
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
|
||||
case <-ctx.Done():
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) activateHITLForConversation(conversationID string, req *HITLRequest) {
|
||||
if h.hitlManager == nil {
|
||||
return
|
||||
}
|
||||
if req == nil {
|
||||
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
|
||||
if err == nil {
|
||||
req = cfg
|
||||
}
|
||||
}
|
||||
h.hitlManager.ActivateConversation(conversationID, h.hitlRequestWithMergedConfigWhitelist(req))
|
||||
}
|
||||
|
||||
func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID, toolName, toolCallID string, payload map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) (*hitlDecision, error) {
|
||||
cfg, need := h.hitlManager.shouldInterrupt(conversationID, toolName)
|
||||
if !need {
|
||||
return nil, nil
|
||||
}
|
||||
payloadRaw, _ := json.Marshal(payload)
|
||||
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
|
||||
if err != nil {
|
||||
h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"mode": cfg.Mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
})
|
||||
}
|
||||
d, waitErr := h.hitlManager.waitDecision(runCtx, p, cfg.Timeout)
|
||||
if waitErr != nil {
|
||||
if cancelRun != nil && (errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded)) {
|
||||
cause := context.Cause(runCtx)
|
||||
switch {
|
||||
case errors.Is(cause, ErrTaskCancelled):
|
||||
cancelRun(ErrTaskCancelled)
|
||||
case cause != nil:
|
||||
cancelRun(cause)
|
||||
case errors.Is(waitErr, context.DeadlineExceeded):
|
||||
cancelRun(context.DeadlineExceeded)
|
||||
default:
|
||||
cancelRun(ErrTaskCancelled)
|
||||
}
|
||||
}
|
||||
return nil, waitErr
|
||||
}
|
||||
if d.Decision == "reject" {
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": d.Comment,
|
||||
})
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_resumed", "人工确认通过,继续执行", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": d.Comment,
|
||||
"editedArgs": d.EditedArguments,
|
||||
})
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, data map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) {
|
||||
if h.hitlManager == nil {
|
||||
return
|
||||
}
|
||||
toolName, _ := data["toolName"].(string)
|
||||
toolCallID, _ := data["toolCallId"].(string)
|
||||
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, data, sendEventFunc)
|
||||
if err != nil || d == nil {
|
||||
return
|
||||
}
|
||||
if len(d.EditedArguments) > 0 {
|
||||
if argsObj, ok := data["argumentsObj"].(map[string]interface{}); ok {
|
||||
for k := range argsObj {
|
||||
delete(argsObj, k)
|
||||
}
|
||||
for k, v := range d.EditedArguments {
|
||||
argsObj[k] = v
|
||||
}
|
||||
if b, mErr := json.Marshal(argsObj); mErr == nil {
|
||||
data["arguments"] = string(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
if status == "" {
|
||||
status = "pending"
|
||||
}
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||
offset := (page - 1) * pageSize
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
q += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if status != "all" {
|
||||
q += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, pageSize, offset)
|
||||
rows, err := h.db.Query(q, args...)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]map[string]interface{}, 0)
|
||||
for rows.Next() {
|
||||
var id, cid, mode, toolName, toolCallID, payload, rowStatus string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
msgID := ""
|
||||
if messageID.Valid {
|
||||
msgID = messageID.String
|
||||
}
|
||||
items = append(items, map[string]interface{}{
|
||||
"id": id,
|
||||
"conversationId": cid,
|
||||
"messageId": msgID,
|
||||
"mode": mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
"status": rowStatus,
|
||||
"decision": decision.String,
|
||||
"comment": comment.String,
|
||||
"createdAt": createdAt,
|
||||
"decidedAt": func() interface{} {
|
||||
if decidedAt.Valid {
|
||||
return decidedAt.Time
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize})
|
||||
}
|
||||
|
||||
type hitlDecisionReq struct {
|
||||
InterruptID string `json:"interruptId" binding:"required"`
|
||||
Decision string `json:"decision" binding:"required"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
EditedArguments map[string]interface{} `json:"editedArguments,omitempty"`
|
||||
}
|
||||
|
||||
func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
|
||||
var req hitlDecisionReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.hitlManager == nil {
|
||||
c.JSON(500, gin.H{"error": "hitl manager unavailable"})
|
||||
return
|
||||
}
|
||||
if err := h.hitlManager.ResolveInterrupt(req.InterruptID, req.Decision, req.Comment, req.EditedArguments); err != nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) interceptHITLForEinoTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName, arguments string) (string, error) {
|
||||
payload := map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"arguments": arguments,
|
||||
"source": "eino_middleware",
|
||||
"toolCallId": "",
|
||||
}
|
||||
var argsObj map[string]interface{}
|
||||
if strings.TrimSpace(arguments) != "" {
|
||||
_ = json.Unmarshal([]byte(arguments), &argsObj)
|
||||
if argsObj != nil {
|
||||
payload["argumentsObj"] = argsObj
|
||||
}
|
||||
}
|
||||
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, "", payload, sendEventFunc)
|
||||
if err != nil || d == nil {
|
||||
return arguments, err
|
||||
}
|
||||
if d.Decision == "reject" {
|
||||
return arguments, multiagent.NewHumanRejectError(d.Comment)
|
||||
}
|
||||
if len(d.EditedArguments) > 0 {
|
||||
edited, mErr := json.Marshal(d.EditedArguments)
|
||||
if mErr == nil {
|
||||
return string(edited), nil
|
||||
}
|
||||
}
|
||||
return arguments, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) interceptHITLForReactTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName string, arguments map[string]interface{}, toolCallID string) (map[string]interface{}, error) {
|
||||
payload := map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"argumentsObj": arguments,
|
||||
"toolCallId": toolCallID,
|
||||
"source": "react_pre_exec",
|
||||
}
|
||||
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, payload, sendEventFunc)
|
||||
if err != nil || d == nil {
|
||||
return arguments, err
|
||||
}
|
||||
if d.Decision == "reject" {
|
||||
comment := strings.TrimSpace(d.Comment)
|
||||
if comment == "" {
|
||||
comment = "no extra feedback"
|
||||
}
|
||||
return arguments, errors.New("human rejected this tool call; feedback: " + comment)
|
||||
}
|
||||
if len(d.EditedArguments) > 0 {
|
||||
return d.EditedArguments, nil
|
||||
}
|
||||
return arguments, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) injectReactHITLInterceptor(ctx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) context.Context {
|
||||
return agent.WithToolCallInterceptor(ctx, func(c context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error) {
|
||||
return h.interceptHITLForReactTool(c, cancelRun, conversationID, assistantMessageID, sendEventFunc, toolName, args, toolCallID)
|
||||
})
|
||||
}
|
||||
|
||||
type hitlConfigReq struct {
|
||||
ConversationID string `json:"conversationId" binding:"required"`
|
||||
HITLRequest
|
||||
}
|
||||
|
||||
func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Param("conversationId"))
|
||||
if conversationID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||||
return
|
||||
}
|
||||
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !hitlStoredConfigEffective(cfg) {
|
||||
if pendMode, ok := h.hitlManager.PendingHITLInterruptMode(conversationID); ok {
|
||||
cfg2 := *cfg
|
||||
cfg2.Enabled = true
|
||||
cfg2.Mode = normalizeHitlMode(pendMode)
|
||||
if cfg2.TimeoutSeconds <= 0 {
|
||||
cfg2.TimeoutSeconds = 300
|
||||
}
|
||||
cfg = &cfg2
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"conversationId": conversationID,
|
||||
"hitl": cfg,
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) {
|
||||
var req hitlConfigReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
req.Mode = normalizeHitlMode(req.Mode)
|
||||
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.hitlWhitelistSaver != nil && len(req.SensitiveTools) > 0 {
|
||||
if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil {
|
||||
h.logger.Warn("HITL 会话配置已保存,但合并工具白名单到 config.yaml 失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "会话配置已保存,但写入 config.yaml 失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
h.hitlManager.ActivateConversation(req.ConversationID, h.hitlRequestWithMergedConfigWhitelist(&req.HITLRequest))
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
type mergeHitlGlobalWhitelistReq struct {
|
||||
SensitiveTools []string `json:"sensitiveTools"`
|
||||
}
|
||||
|
||||
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
|
||||
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
if h.hitlWhitelistSaver == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
|
||||
return
|
||||
}
|
||||
var req mergeHitlGlobalWhitelistReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(req.SensitiveTools) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalWhitelistMerged": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil {
|
||||
h.logger.Warn("合并 HITL 工具白名单到 config.yaml 失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalWhitelistMerged": true,
|
||||
})
|
||||
}
|
||||
|
||||
func boolToInt(v bool) int {
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -53,25 +53,36 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
clientDisconnected := false
|
||||
// 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。
|
||||
var sseWriteMu sync.Mutex
|
||||
var ssePublishConversationID string
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
// 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。
|
||||
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
|
||||
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) {
|
||||
return
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, errMarshal := json.Marshal(ev)
|
||||
if errMarshal != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
sseLine := make([]byte, 0, len(b)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, b...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
if ssePublishConversationID != "" && h.taskEventBus != nil {
|
||||
h.taskEventBus.Publish(ssePublishConversationID, sseLine)
|
||||
}
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
clientDisconnected = true
|
||||
return
|
||||
default:
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, _ := json.Marshal(ev)
|
||||
sseWriteMu.Lock()
|
||||
_, err := fmt.Fprintf(c.Writer, "data: %s\n\n", b)
|
||||
_, err := c.Writer.Write(sseLine)
|
||||
if err != nil {
|
||||
sseWriteMu.Unlock()
|
||||
clientDisconnected = true
|
||||
@@ -95,6 +106,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
sendEvent("done", "", nil)
|
||||
return
|
||||
}
|
||||
ssePublishConversationID = prep.ConversationID
|
||||
if prep.CreatedNew {
|
||||
sendEvent("conversation", "会话已创建", map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
@@ -103,6 +115,10 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
h.activateHITLForConversation(conversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(conversationID)
|
||||
}
|
||||
|
||||
if prep.UserMessageID != "" {
|
||||
sendEvent("message_saved", "", map[string]interface{}{
|
||||
@@ -111,12 +127,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
@@ -181,6 +199,23 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
|
||||
taskStatus = "timeout"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
|
||||
taskStatus = "failed"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
@@ -251,9 +286,20 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
return
|
||||
}
|
||||
h.activateHITLForConversation(prep.ConversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil)
|
||||
baseCtx = multiagent.WithHITLToolInterceptor(baseCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr := multiagent.RunDeepAgent(
|
||||
c.Request.Context(),
|
||||
baseCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
@@ -262,7 +308,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
prep.FinalMessage,
|
||||
prep.History,
|
||||
prep.RoleTools,
|
||||
nil,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
package handler
|
||||
|
||||
import "sync"
|
||||
|
||||
// TaskEventBus 将主 SSE 连接上的事件镜像给后订阅的客户端(例如刷新页面后、HITL 审批通过需继续收事件)。
|
||||
// 每个 payload 为完整 SSE 行: "data: {...}\n\n"
|
||||
type TaskEventBus struct {
|
||||
mu sync.RWMutex
|
||||
subs map[string]map[*taskEventSub]struct{}
|
||||
}
|
||||
|
||||
type taskEventSub struct {
|
||||
mu sync.Mutex
|
||||
ch chan []byte
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (s *taskEventSub) sendNonBlocking(line []byte) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case s.ch <- line:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *taskEventSub) closeOnce() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
close(s.ch)
|
||||
}
|
||||
|
||||
func NewTaskEventBus() *TaskEventBus {
|
||||
return &TaskEventBus{
|
||||
subs: make(map[string]map[*taskEventSub]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe 注册订阅;cancel 时需调用 Unsubscribe。
|
||||
func (b *TaskEventBus) Subscribe(conversationID string) (sub *taskEventSub, ch <-chan []byte) {
|
||||
chBuf := make(chan []byte, 256)
|
||||
sub = &taskEventSub{ch: chBuf}
|
||||
b.mu.Lock()
|
||||
if b.subs[conversationID] == nil {
|
||||
b.subs[conversationID] = make(map[*taskEventSub]struct{})
|
||||
}
|
||||
b.subs[conversationID][sub] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return sub, chBuf
|
||||
}
|
||||
|
||||
func (b *TaskEventBus) Unsubscribe(conversationID string, sub *taskEventSub) {
|
||||
if sub == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
m, ok := b.subs[conversationID]
|
||||
if !ok {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(m, sub)
|
||||
if len(m) == 0 {
|
||||
delete(b.subs, conversationID)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
sub.closeOnce()
|
||||
}
|
||||
|
||||
// Publish 非阻塞投递;慢消费者丢帧(HITL 场景以最新状态为准,丢帧可接受)。
|
||||
func (b *TaskEventBus) Publish(conversationID string, line []byte) {
|
||||
if b == nil || conversationID == "" || len(line) == 0 {
|
||||
return
|
||||
}
|
||||
b.mu.RLock()
|
||||
m := b.subs[conversationID]
|
||||
subs := make([]*taskEventSub, 0, len(m))
|
||||
for s := range m {
|
||||
subs = append(subs, s)
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
|
||||
cp := append([]byte(nil), line...)
|
||||
for _, s := range subs {
|
||||
s.sendNonBlocking(cp)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseConversation 任务结束时关闭该会话所有订阅 channel。
|
||||
func (b *TaskEventBus) CloseConversation(conversationID string) {
|
||||
if b == nil || conversationID == "" {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
m := b.subs[conversationID]
|
||||
delete(b.subs, conversationID)
|
||||
b.mu.Unlock()
|
||||
for sub := range m {
|
||||
sub.closeOnce()
|
||||
}
|
||||
}
|
||||
@@ -35,11 +35,12 @@ type CompletedTask struct {
|
||||
|
||||
// AgentTaskManager 管理正在运行的Agent任务
|
||||
type AgentTaskManager struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*AgentTask
|
||||
completedTasks []*CompletedTask // 最近完成的任务历史
|
||||
maxHistorySize int // 最大历史记录数
|
||||
historyRetention time.Duration // 历史记录保留时间
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*AgentTask
|
||||
completedTasks []*CompletedTask // 最近完成的任务历史
|
||||
maxHistorySize int // 最大历史记录数
|
||||
historyRetention time.Duration // 历史记录保留时间
|
||||
eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -56,13 +57,27 @@ func NewAgentTaskManager() *AgentTaskManager {
|
||||
m := &AgentTaskManager{
|
||||
tasks: make(map[string]*AgentTask),
|
||||
completedTasks: make([]*CompletedTask, 0),
|
||||
maxHistorySize: 50, // 最多保留50条历史记录
|
||||
historyRetention: 24 * time.Hour, // 保留24小时
|
||||
maxHistorySize: 50, // 最多保留50条历史记录
|
||||
historyRetention: 24 * time.Hour, // 保留24小时
|
||||
}
|
||||
go m.runStuckCancellingCleanup()
|
||||
return m
|
||||
}
|
||||
|
||||
// SetTaskEventBus 设置任务事件总线(与 AgentHandler 共用同一实例)。
|
||||
func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.eventBus = b
|
||||
}
|
||||
|
||||
// GetTask 返回运行中任务(无则 nil)。
|
||||
func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.tasks[conversationID]
|
||||
}
|
||||
|
||||
// runStuckCancellingCleanup 定期将长时间处于「取消中」的任务强制结束,避免卡住无法发新消息
|
||||
func (m *AgentTaskManager) runStuckCancellingCleanup() {
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
@@ -172,10 +187,9 @@ func (m *AgentTaskManager) UpdateTaskStatus(conversationID string, status string
|
||||
// FinishTask 完成任务并从管理器中移除
|
||||
func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
task, exists := m.tasks[conversationID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -187,26 +201,31 @@ func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string)
|
||||
completedTask := &CompletedTask{
|
||||
ConversationID: task.ConversationID,
|
||||
Message: task.Message,
|
||||
StartedAt: task.StartedAt,
|
||||
CompletedAt: time.Now(),
|
||||
Status: finalStatus,
|
||||
StartedAt: task.StartedAt,
|
||||
CompletedAt: time.Now(),
|
||||
Status: finalStatus,
|
||||
}
|
||||
|
||||
|
||||
// 添加到历史记录
|
||||
m.completedTasks = append(m.completedTasks, completedTask)
|
||||
|
||||
|
||||
// 清理过期和过多的历史记录
|
||||
m.cleanupHistory()
|
||||
|
||||
// 从运行任务中移除
|
||||
delete(m.tasks, conversationID)
|
||||
bus := m.eventBus
|
||||
m.mu.Unlock()
|
||||
if bus != nil {
|
||||
bus.CloseConversation(conversationID)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupHistory 清理过期的历史记录
|
||||
func (m *AgentTaskManager) cleanupHistory() {
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
|
||||
// 过滤掉过期的记录
|
||||
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
@@ -214,7 +233,7 @@ func (m *AgentTaskManager) cleanupHistory() {
|
||||
validTasks = append(validTasks, task)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果仍然超过最大数量,只保留最新的
|
||||
if len(validTasks) > m.maxHistorySize {
|
||||
// 按完成时间排序,保留最新的
|
||||
@@ -222,7 +241,7 @@ func (m *AgentTaskManager) cleanupHistory() {
|
||||
start := len(validTasks) - m.maxHistorySize
|
||||
validTasks = validTasks[start:]
|
||||
}
|
||||
|
||||
|
||||
m.completedTasks = validTasks
|
||||
}
|
||||
|
||||
@@ -247,30 +266,30 @@ func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
|
||||
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
|
||||
// 清理过期记录(只读锁,不影响其他操作)
|
||||
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
|
||||
// 所以返回时过滤过期记录
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
|
||||
result := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
if task.CompletedAt.After(cutoffTime) {
|
||||
result = append(result, task)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 按完成时间倒序排序(最新的在前)
|
||||
// 由于是追加的,最新的在最后,需要反转
|
||||
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
|
||||
|
||||
// 限制返回数量
|
||||
if len(result) > m.maxHistorySize {
|
||||
result = result[:m.maxHistorySize]
|
||||
}
|
||||
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user