diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 1222d69e..220ea95b 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -77,6 +77,13 @@ type responsePlanAgg struct { b strings.Builder } +// thinkingBuf aggregates thinking_stream_* / reasoning_chain_stream_* before flush to process_details. +type thinkingBuf struct { + b strings.Builder + meta map[string]interface{} + persistAs string // "thinking" | "reasoning_chain" +} + func normalizeProcessDetailText(s string) string { s = strings.ReplaceAll(s, "\r\n", "\n") s = strings.ReplaceAll(s, "\r", "\n") @@ -179,6 +186,8 @@ type AgentHandler struct { batchCronParser cron.Parser // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) hitlWhitelistSaver HitlToolWhitelistSaver + hitlStrategySaver HitlAuditStrategySaver + auditLLM *openai.Client audit *audit.Service } @@ -218,9 +227,10 @@ func (h *AgentHandler) cancelActiveMCPToolForConversation(conversationID string) } } -// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 +// HitlToolWhitelistSaver 合并/设置 HITL 免审批工具到全局配置并落盘 type HitlToolWhitelistSaver interface { MergeHitlToolWhitelistIntoConfig(add []string) error + SetHitlToolWhitelist(tools []string) error } // NewAgentHandler 创建新的Agent处理器 @@ -236,6 +246,11 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo bus := NewTaskEventBus() tm := NewAgentTaskManager() tm.SetTaskEventBus(bus) + llmHTTP := &http.Client{Timeout: 2 * time.Minute} + var llmCfg *config.OpenAIConfig + if cfg != nil { + llmCfg = &cfg.OpenAI + } handler := &AgentHandler{ agent: agent, db: db, @@ -246,6 +261,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo config: cfg, hitlManager: NewHITLManager(db, logger), batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), + auditLLM: openai.NewClient(llmCfg, llmHTTP, logger), } tm.SetToolCanceler(handler.cancelActiveMCPToolForConversation) if err := handler.hitlManager.EnsureSchema(); err != nil { @@ -320,6 +336,7 @@ func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientInten type HITLRequest struct { Enabled bool `json:"enabled"` Mode string `json:"mode,omitempty"` + Reviewer string `json:"reviewer,omitempty"` // human | audit_agent SensitiveTools []string `json:"sensitiveTools,omitempty"` TimeoutSeconds int `json:"timeoutSeconds,omitempty"` } @@ -849,11 +866,6 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun // thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent): // 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。 - type thinkingBuf struct { - b strings.Builder - meta map[string]interface{} - persistAs string // "thinking" | "reasoning_chain" - } thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf flushedThinking := make(map[string]bool) // streamId -> flushed seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature @@ -866,6 +878,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun // response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta; // 聚合为一条 planning 写入 process_details,刷新后与线上一致。 var respPlan responsePlanAgg + if assistantMessageID != "" { + h.tasks.SetHitlAssistantMessageID(conversationID, assistantMessageID) + } + syncHitlCognition := func() { + h.syncHitlCognitionFromProgress(conversationID, assistantMessageID, thinkingStreams, &respPlan) + } flushResponsePlan := func() { if assistantMessageID == "" { return @@ -885,6 +903,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil { h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning")) } + syncHitlCognition() respPlan.meta = nil respPlan.b.Reset() } @@ -921,6 +940,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun } flushedThinking[sid] = true } + syncHitlCognition() } return func(eventType, message string, data interface{}) { @@ -981,6 +1001,25 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun } } + if eventType == "tool_result" { + if dataMap, ok := data.(map[string]interface{}); ok { + toolName, _ := dataMap["toolName"].(string) + toolCallID, _ := dataMap["toolCallId"].(string) + success := true + if v, ok := dataMap["success"].(bool); ok { + success = v + } + resultText := "" + if r, ok := dataMap["result"].(string); ok { + resultText = r + } + if strings.TrimSpace(resultText) == "" { + resultText = message + } + h.recordHitlToolExecutionResult(conversationID, toolCallID, toolName, success, resultText) + } + } + // 处理知识检索日志记录 if eventType == "tool_result" && h.knowledgeManager != nil { if dataMap, ok := data.(map[string]interface{}); ok { @@ -1188,6 +1227,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun respPlan.meta[k] = v } } + syncHitlCognition() return } if eventType == "response" { @@ -1257,6 +1297,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun } } } + syncHitlCognition() return } diff --git a/internal/handler/config.go b/internal/handler/config.go index 3dd4d804..761fae57 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -1766,6 +1766,20 @@ func mergeHitlToolWhitelistSlice(existing, add []string) []string { return out } +// SetHitlToolWhitelist 将全局免审批工具白名单整表写入 config.yaml(替换,非合并)。 +func (h *ConfigHandler) SetHitlToolWhitelist(tools []string) error { + h.mu.Lock() + defer h.mu.Unlock() + h.config.Hitl.ToolWhitelist = mergeHitlToolWhitelistSlice(nil, tools) + if err := h.saveConfig(); err != nil { + return err + } + h.logger.Info("HITL 全局工具白名单已写入配置文件", + zap.Int("count", len(h.config.Hitl.ToolWhitelist)), + ) + return nil +} + // MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。 func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error { h.mu.Lock() @@ -1786,6 +1800,21 @@ func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) { hitlNode := ensureMap(root, "hitl") // flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数 setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist) + setStringInMap(hitlNode, "audit_agent_prompt", cfg.AuditAgentPrompt) + setStringInMap(hitlNode, "audit_agent_prompt_review_edit", cfg.AuditAgentPromptReviewEdit) +} + +// UpdateHitlAuditAgentStrategy 更新审批/审查编辑两套审计 Agent 提示词并写入 config.yaml。 +func (h *ConfigHandler) UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error { + h.mu.Lock() + defer h.mu.Unlock() + h.config.Hitl.AuditAgentPrompt = strings.TrimSpace(approvalPrompt) + h.config.Hitl.AuditAgentPromptReviewEdit = strings.TrimSpace(reviewEditPrompt) + if err := h.saveConfig(); err != nil { + return err + } + h.logger.Info("HITL 审计 Agent 提示词已写入配置文件") + return nil } func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) { diff --git a/internal/handler/hitl.go b/internal/handler/hitl.go index a6759639..274ee07a 100644 --- a/internal/handler/hitl.go +++ b/internal/handler/hitl.go @@ -23,6 +23,7 @@ import ( type hitlRuntimeConfig struct { Enabled bool Mode string + Reviewer string SensitiveTools map[string]struct{} Timeout time.Duration } @@ -49,6 +50,8 @@ type HITLManager struct { mu sync.RWMutex runtime map[string]hitlRuntimeConfig pending map[string]*pendingInterrupt + // approvedExec 审批通过、待回写 tool_result 的队列(按会话 FIFO) + approvedExec map[string][]hitlApprovedExecTrack } func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager { @@ -90,6 +93,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs ( if err != nil { return err } + m.migrateHitlSchemaColumns() // On startup, cancel all orphaned pending interrupts from previous process. // Their in-memory channels are gone, so they can never be resolved. @@ -141,6 +145,7 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque m.runtime[conversationID] = hitlRuntimeConfig{ Enabled: true, Mode: normalizeHitlMode(req.Mode), + Reviewer: normalizeHitlReviewer(req.Reviewer), SensitiveTools: tools, Timeout: timeout, } @@ -362,22 +367,22 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq timeout = 0 } _, err := m.db.Exec(`INSERT INTO hitl_conversation_configs - (conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at) - VALUES (?, ?, ?, ?, ?, ?) + (conversation_id, enabled, mode, reviewer, sensitive_tools, timeout_seconds, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT(conversation_id) DO UPDATE SET - enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`, - conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now()) + enabled=excluded.enabled, mode=excluded.mode, reviewer=excluded.reviewer, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`, + conversationID, boolToInt(req.Enabled), mode, normalizeHitlReviewer(req.Reviewer), string(tools), timeout, time.Now()) return err } func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) { var enabledInt int - var mode, toolsJSON string + var mode, reviewer, toolsJSON string var timeout int - err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID). - Scan(&enabledInt, &mode, &toolsJSON, &timeout) + err := m.db.QueryRow(`SELECT enabled, mode, COALESCE(reviewer,'human'), sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID). + Scan(&enabledInt, &mode, &reviewer, &toolsJSON, &timeout) if errors.Is(err, sql.ErrNoRows) { - return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil + return &HITLRequest{Enabled: false, Mode: "off", Reviewer: "human", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil } if err != nil { return nil, err @@ -390,6 +395,7 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques return &HITLRequest{ Enabled: enabledInt == 1, Mode: mode, + Reviewer: normalizeHitlReviewer(reviewer), SensitiveTools: tools, TimeoutSeconds: timeout, }, nil @@ -413,15 +419,15 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim if p.Mode != "review_edit" && len(d.EditedArguments) > 0 { d.EditedArguments = nil } - _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`, + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='human' WHERE id=?`, d.Decision, d.Comment, time.Now(), p.InterruptID) return d, nil case <-timeoutCh: - _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`, + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=?, decided_by='system' WHERE id=?`, time.Now(), p.InterruptID) return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil case <-ctx.Done(): - _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`, + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=?, decided_by='system' WHERE id=?`, time.Now(), p.InterruptID) return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err() } @@ -445,12 +451,57 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex if !need { return nil, nil } + h.enrichHitlApprovalPayload(conversationID, assistantMessageID, payload) payloadRaw, _ := json.Marshal(payload) p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw)) if err != nil { h.logger.Warn("创建 HITL 中断失败", zap.Error(err)) return nil, err } + + if cfg.Reviewer == "audit_agent" { + ad := h.auditAgentReview(runCtx, cfg.Mode, toolName, payload) + now := time.Now() + _, _ = h.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='audit_agent' WHERE id=?`, + ad.Decision, ad.Comment, now, p.InterruptID) + if sendEventFunc != nil { + sendEventFunc("hitl_audit_agent", "审计 Agent 已裁决", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "toolName": toolName, + "mode": cfg.Mode, + "decision": ad.Decision, + "comment": ad.Comment, + "editedArgs": ad.EditedArguments, + "decidedBy": "audit_agent", + }) + } + if ad.Decision == "reject" { + if sendEventFunc != nil { + sendEventFunc("hitl_rejected", "审计 Agent 拒绝本次工具调用", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "toolName": toolName, + "comment": ad.Comment, + "decidedBy": "audit_agent", + }) + } + return &ad, nil + } + if sendEventFunc != nil { + sendEventFunc("hitl_resumed", "审计 Agent 已通过,继续执行", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "toolName": toolName, + "comment": ad.Comment, + "editedArgs": ad.EditedArguments, + "decidedBy": "audit_agent", + }) + } + h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID) + return &ad, nil + } + if sendEventFunc != nil { sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{ "conversationId": conversationID, @@ -498,6 +549,7 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex "editedArgs": d.EditedArguments, }) } + h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID) return &d, nil } @@ -527,11 +579,6 @@ func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun cont } func (h *AgentHandler) ListHITLPending(c *gin.Context) { - conversationID := strings.TrimSpace(c.Query("conversationId")) - status := strings.TrimSpace(c.Query("status")) - if status == "" { - status = "pending" - } page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) if page < 1 { page = 1 @@ -539,15 +586,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) { pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) pageSize = int(math.Max(1, math.Min(float64(pageSize), 200))) offset := (page - 1) * pageSize - q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1` - args := []interface{}{} - if conversationID != "" { - q += " AND conversation_id = ?" - args = append(args, conversationID) - } - if status != "all" { - q += " AND status = ?" - args = append(args, status) + q, args := h.buildHitlListQuery(false) + q, args = h.appendHitlListFilters(q, args, c) + total, err := h.countHitlQuery(q, args) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } q += " ORDER BY created_at DESC LIMIT ? OFFSET ?" args = append(args, pageSize, offset) @@ -557,41 +601,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) { return } defer rows.Close() - items := make([]map[string]interface{}, 0) - for rows.Next() { - var id, cid, mode, toolName, toolCallID, payload, rowStatus string - var messageID sql.NullString - var decision, comment sql.NullString - var createdAt time.Time - var decidedAt sql.NullTime - if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil { - continue - } - msgID := "" - if messageID.Valid { - msgID = messageID.String - } - items = append(items, map[string]interface{}{ - "id": id, - "conversationId": cid, - "messageId": msgID, - "mode": mode, - "toolName": toolName, - "toolCallId": toolCallID, - "payload": payload, - "status": rowStatus, - "decision": decision.String, - "comment": comment.String, - "createdAt": createdAt, - "decidedAt": func() interface{} { - if decidedAt.Valid { - return decidedAt.Time - } - return nil - }(), - }) + items, err := h.scanHitlInterruptRows(rows) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } - c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize}) + c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total}) } type hitlDecisionReq struct { @@ -636,7 +651,7 @@ func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) { return } res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', - decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP + decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP, decided_by='human' WHERE id=? AND status='pending'`, req.InterruptID) if err != nil { c.JSON(500, gin.H{"error": err.Error()}) @@ -732,6 +747,7 @@ func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) { return } req.Mode = normalizeHitlMode(req.Mode) + req.Reviewer = normalizeHitlReviewer(req.Reviewer) if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -753,6 +769,44 @@ type mergeHitlGlobalWhitelistReq struct { SensitiveTools []string `json:"sensitiveTools"` } +type setHitlGlobalWhitelistReq struct { + ToolWhitelist []string `json:"toolWhitelist"` +} + +// GetHITLGlobalToolWhitelist 返回 config.yaml 中的全局免审批工具白名单。 +func (h *AgentHandler) GetHITLGlobalToolWhitelist(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "toolWhitelist": h.hitlConfigGlobalToolWhitelist(), + }) +} + +// SetHITLGlobalToolWhitelist 整表替换 config.yaml 中的全局免审批工具白名单。 +func (h *AgentHandler) SetHITLGlobalToolWhitelist(c *gin.Context) { + if h.hitlWhitelistSaver == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"}) + return + } + var req setHitlGlobalWhitelistReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := h.hitlWhitelistSaver.SetHitlToolWhitelist(req.ToolWhitelist); err != nil { + h.logger.Warn("写入 HITL 工具白名单到 config.yaml 失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "hitl", "tool_whitelist_update", "HITL 全局白名单更新", "hitl_config", "tool_whitelist", nil) + } + c.JSON(http.StatusOK, gin.H{ + "ok": true, + "toolWhitelist": h.hitlConfigGlobalToolWhitelist(), + "hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(), + "hitlGlobalWhitelistMerged": false, + }) +} + // MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。 func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) { if h.hitlWhitelistSaver == nil { diff --git a/internal/handler/hitl_audit_agent.go b/internal/handler/hitl_audit_agent.go new file mode 100644 index 00000000..c06bd312 --- /dev/null +++ b/internal/handler/hitl_audit_agent.go @@ -0,0 +1,357 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// auditAgentReview 在 reviewer=audit_agent 时由 LLM 代行审批。 +// 白名单工具在 shouldInterrupt 阶段已跳过,到达此处的一律需要裁决。 +func (h *AgentHandler) auditAgentReview(ctx context.Context, hitlMode, toolName string, payload map[string]interface{}) hitlDecision { + if h == nil { + return hitlDecision{Decision: "reject", Comment: "audit agent: handler unavailable"} + } + mode := normalizeHitlMode(hitlMode) + prompt := config.DefaultHitlAuditAgentPrompt() + if h.config != nil { + prompt = h.config.Hitl.EffectiveAuditAgentPromptForMode(mode) + } + if h.auditLLM == nil { + return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 未配置"} + } + if ctx == nil { + ctx = context.Background() + } + callCtx, cancel := context.WithTimeout(ctx, 90*time.Second) + defer cancel() + + userContent := buildAuditAgentReviewInput(mode, toolName, payload) + requestBody := map[string]interface{}{ + "model": h.auditLLMModel(), + "messages": []map[string]interface{}{ + {"role": "system", "content": prompt}, + {"role": "user", "content": userContent}, + }, + "temperature": 0.1, + "max_completion_tokens": 1024, + // 审计裁决需要结构化 JSON;关闭 thinking 避免 Qwen 等把正文放进 reasoning_content 导致解析失败。 + "thinking": map[string]interface{}{"type": "disabled"}, + } + + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + } `json:"message"` + } `json:"choices"` + } + if err := h.auditLLM.ChatCompletion(callCtx, requestBody, &apiResponse); err != nil { + h.logger.Warn("审计 Agent LLM 调用失败", zap.Error(err), zap.String("tool", toolName)) + return hitlDecision{ + Decision: "reject", + Comment: "audit agent: LLM 调用失败,保守拒绝", + } + } + if len(apiResponse.Choices) == 0 { + return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 无有效响应,保守拒绝"} + } + msg := apiResponse.Choices[0].Message + raw := strings.TrimSpace(msg.Content) + if raw == "" { + raw = strings.TrimSpace(msg.ReasoningContent) + } + dec, err := parseAuditAgentLLMContent(raw) + if err != nil { + snippet := raw + if len(snippet) > 240 { + snippet = snippet[:240] + "..." + } + h.logger.Warn("审计 Agent 响应解析失败", + zap.Error(err), + zap.String("tool", toolName), + zap.String("mode", mode), + zap.String("snippet", snippet), + ) + return hitlDecision{Decision: "reject", Comment: "audit agent: 响应无法解析,保守拒绝"} + } + if mode != "review_edit" && len(dec.EditedArguments) > 0 { + h.logger.Warn("审计 Agent 在审批模式下返回 editedArguments,已忽略", + zap.String("tool", toolName), + ) + dec.EditedArguments = nil + } + if dec.Comment == "" { + dec.Comment = "audit agent: " + dec.Decision + } else if !strings.HasPrefix(strings.ToLower(dec.Comment), "audit agent") { + dec.Comment = "audit agent: " + dec.Comment + } + return dec +} + +func (h *AgentHandler) auditLLMModel() string { + if h.config != nil && strings.TrimSpace(h.config.OpenAI.Model) != "" { + return strings.TrimSpace(h.config.OpenAI.Model) + } + return "" +} + +func buildAuditAgentReviewInput(hitlMode, toolName string, payload map[string]interface{}) string { + review := map[string]interface{}{ + "hitlMode": normalizeHitlMode(hitlMode), + "toolName": strings.TrimSpace(toolName), + } + if payload != nil { + for _, k := range []string{"arguments", "argumentsObj", "command", hitlPayloadUserMessage, hitlPayloadThinking, hitlPayloadReasoningChain, hitlPayloadPlanning} { + if v, ok := payload[k]; ok && v != nil && fmt.Sprint(v) != "" { + review[k] = v + } + } + } + b, err := json.MarshalIndent(review, "", " ") + if err != nil { + return fmt.Sprintf(`{"hitlMode":%q,"toolName":%q}`, normalizeHitlMode(hitlMode), toolName) + } + return string(b) +} + +func parseAuditAgentLLMContent(content string) (hitlDecision, error) { + s := strings.TrimSpace(content) + if s == "" { + return hitlDecision{}, errors.New("empty content") + } + for _, candidate := range auditAgentJSONCandidates(s) { + dec, comment, editedArgs, err := parseAuditAgentDecisionObject(candidate) + if err == nil { + return hitlDecision{ + Decision: dec, + Comment: comment, + EditedArguments: editedArgs, + }, nil + } + } + return hitlDecision{}, fmt.Errorf("no valid decision json in response") +} + +func auditAgentJSONCandidates(s string) []string { + out := make([]string, 0, 4) + seen := make(map[string]struct{}) + add := func(c string) { + c = strings.TrimSpace(c) + if c == "" { + return + } + if _, ok := seen[c]; ok { + return + } + seen[c] = struct{}{} + out = append(out, c) + } + add(s) + add(stripMarkdownCodeFence(s)) + if obj := extractFirstJSONObject(s); obj != "" { + add(obj) + } + if obj := extractFirstJSONObject(stripMarkdownCodeFence(s)); obj != "" { + add(obj) + } + return out +} + +func stripMarkdownCodeFence(s string) string { + s = strings.TrimSpace(s) + for _, fence := range []string{"```json", "```JSON", "```"} { + if strings.HasPrefix(s, fence) { + s = strings.TrimPrefix(s, fence) + } + } + s = strings.TrimSuffix(s, "```") + return strings.TrimSpace(s) +} + +func extractFirstJSONObject(s string) string { + start := strings.Index(s, "{") + if start < 0 { + return "" + } + depth := 0 + inStr := false + esc := false + for i := start; i < len(s); i++ { + ch := s[i] + if inStr { + if esc { + esc = false + continue + } + if ch == '\\' { + esc = true + continue + } + if ch == '"' { + inStr = false + } + continue + } + switch ch { + case '"': + inStr = true + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + return s[start : i+1] + } + } + } + return "" +} + +func parseAuditAgentDecisionObject(jsonText string) (decision, comment string, editedArgs map[string]interface{}, err error) { + var parsed map[string]interface{} + if err := json.Unmarshal([]byte(jsonText), &parsed); err != nil { + return "", "", nil, err + } + rawDecision := auditAgentPickString(parsed, "decision", "Decision", "result", "action", "verdict", "决策", "决定") + decision = normalizeAuditAgentDecision(rawDecision) + if decision == "" { + return "", "", nil, fmt.Errorf("missing decision") + } + comment = auditAgentPickString(parsed, "comment", "Comment", "reason", "message", "rationale", "备注", "理由", "说明") + editedArgs = auditAgentPickObject(parsed, "editedArguments", "edited_arguments", "editedArgs") + return decision, strings.TrimSpace(comment), editedArgs, nil +} + +func auditAgentPickString(m map[string]interface{}, keys ...string) string { + for _, k := range keys { + if v, ok := m[k]; ok && v != nil { + s := strings.TrimSpace(fmt.Sprint(v)) + if s != "" { + return s + } + } + } + return "" +} + +func auditAgentPickObject(m map[string]interface{}, keys ...string) map[string]interface{} { + for _, k := range keys { + v, ok := m[k] + if !ok || v == nil { + continue + } + switch t := v.(type) { + case map[string]interface{}: + if len(t) > 0 { + return t + } + case string: + s := strings.TrimSpace(t) + if s == "" || s == "{}" { + continue + } + var obj map[string]interface{} + if err := json.Unmarshal([]byte(s), &obj); err == nil && len(obj) > 0 { + return obj + } + } + } + return nil +} + +func normalizeAuditAgentDecision(v string) string { + d := strings.ToLower(strings.TrimSpace(v)) + switch d { + case "approve", "approved", "pass", "passed", "allow", "allowed", "yes", "ok", "accept", "accepted": + return "approve" + case "reject", "rejected", "deny", "denied", "no", "block", "blocked", "refuse", "refused": + return "reject" + } + switch strings.TrimSpace(v) { + case "通过", "批准", "允许", "同意", "放行": + return "approve" + case "拒绝", "驳回", "禁止", "否决": + return "reject" + } + return "" +} + +type hitlAuditStrategyReq struct { + AuditAgentPrompt string `json:"auditAgentPrompt"` + AuditAgentPromptReviewEdit string `json:"auditAgentPromptReviewEdit"` +} + +func (h *AgentHandler) GetHITLAuditStrategy(c *gin.Context) { + approvalPrompt := config.DefaultHitlAuditAgentPrompt() + reviewEditPrompt := config.DefaultHitlAuditAgentPromptReviewEdit() + approvalCustom := false + reviewEditCustom := false + if h.config != nil { + approvalPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("approval") + reviewEditPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("review_edit") + approvalCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPrompt) != "" + reviewEditCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPromptReviewEdit) != "" + } + c.JSON(http.StatusOK, gin.H{ + "auditAgentPrompt": approvalPrompt, + "auditAgentPromptCustom": approvalCustom, + "auditAgentPromptReviewEdit": reviewEditPrompt, + "auditAgentPromptReviewEditCustom": reviewEditCustom, + "defaultAuditAgentPrompt": config.DefaultHitlAuditAgentPrompt(), + "defaultAuditAgentPromptReviewEdit": config.DefaultHitlAuditAgentPromptReviewEdit(), + }) +} + +func (h *AgentHandler) UpdateHITLAuditStrategy(c *gin.Context) { + if h.hitlStrategySaver == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 策略持久化不可用"}) + return + } + var req hitlAuditStrategyReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + approvalPrompt := strings.TrimSpace(req.AuditAgentPrompt) + reviewEditPrompt := strings.TrimSpace(req.AuditAgentPromptReviewEdit) + if err := h.hitlStrategySaver.UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt); err != nil { + h.logger.Warn("保存审计 Agent 提示词失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "hitl", "audit_strategy_update", "HITL 审计策略更新", "hitl_config", "audit_agent_prompt", nil) + } + if h.config != nil { + h.config.Hitl.AuditAgentPrompt = approvalPrompt + h.config.Hitl.AuditAgentPromptReviewEdit = reviewEditPrompt + } + c.JSON(http.StatusOK, gin.H{ + "ok": true, + "auditAgentPrompt": config.HitlConfig{AuditAgentPrompt: approvalPrompt}.EffectiveAuditAgentPromptForMode("approval"), + "auditAgentPromptCustom": approvalPrompt != "", + "auditAgentPromptReviewEdit": config.HitlConfig{AuditAgentPromptReviewEdit: reviewEditPrompt}.EffectiveAuditAgentPromptForMode("review_edit"), + "auditAgentPromptReviewEditCustom": reviewEditPrompt != "", + }) +} + +// HitlAuditStrategySaver 持久化审计 Agent 提示词到 config.yaml。 +type HitlAuditStrategySaver interface { + UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error +} + +// SetHitlAuditStrategySaver 设置审计策略落盘。 +func (h *AgentHandler) SetHitlAuditStrategySaver(s HitlAuditStrategySaver) { + h.hitlStrategySaver = s +} diff --git a/internal/handler/hitl_audit_agent_test.go b/internal/handler/hitl_audit_agent_test.go new file mode 100644 index 00000000..8a7d9a4c --- /dev/null +++ b/internal/handler/hitl_audit_agent_test.go @@ -0,0 +1,88 @@ +package handler + +import ( + "strings" + "testing" +) + +func TestParseAuditAgentLLMContentApprove(t *testing.T) { + d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"与任务一致"}`) + if err != nil { + t.Fatal(err) + } + if d.Decision != "approve" || d.Comment != "与任务一致" { + t.Fatalf("unexpected %+v", d) + } +} + +func TestParseAuditAgentLLMContentReject(t *testing.T) { + d, err := parseAuditAgentLLMContent("```json\n{\"decision\":\"reject\",\"comment\":\"风险过高\"}\n```") + if err != nil { + t.Fatal(err) + } + if d.Decision != "reject" { + t.Fatalf("expected reject, got %s", d.Decision) + } +} + +func TestParseAuditAgentLLMContentInvalid(t *testing.T) { + _, err := parseAuditAgentLLMContent(`{"decision":"maybe"}`) + if err == nil { + t.Fatal("expected error for invalid decision") + } +} + +func TestParseAuditAgentLLMContentProseWrapped(t *testing.T) { + d, err := parseAuditAgentLLMContent("好的,裁决如下:\n```json\n{\"decision\":\"approve\",\"comment\":\"只读 ls\"}\n```\n以上。") + if err != nil { + t.Fatal(err) + } + if d.Decision != "approve" { + t.Fatalf("expected approve, got %s", d.Decision) + } +} + +func TestParseAuditAgentLLMContentChineseDecision(t *testing.T) { + d, err := parseAuditAgentLLMContent(`{"decision":"通过","comment":"风险低"}`) + if err != nil { + t.Fatal(err) + } + if d.Decision != "approve" { + t.Fatalf("expected approve, got %s", d.Decision) + } +} + +func TestParseAuditAgentLLMContentWithEditedArguments(t *testing.T) { + d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"收窄路径","editedArguments":{"path":"/safe"}}`) + if err != nil { + t.Fatal(err) + } + if d.Decision != "approve" { + t.Fatalf("expected approve, got %s", d.Decision) + } + if d.EditedArguments == nil || d.EditedArguments["path"] != "/safe" { + t.Fatalf("unexpected edited args: %+v", d.EditedArguments) + } +} + +func TestBuildAuditAgentReviewInputIncludesMode(t *testing.T) { + s := buildAuditAgentReviewInput("review_edit", "execute", map[string]interface{}{ + "arguments": `{"command":"pwd"}`, + }) + if !strings.Contains(s, "review_edit") || !strings.Contains(s, "execute") { + t.Fatalf("unexpected input: %s", s) + } +} + +func TestBuildAuditAgentReviewInput(t *testing.T) { + s := buildAuditAgentReviewInput("approval", "nmap", map[string]interface{}{ + "arguments": `{"target":"10.0.0.1"}`, + "userMessage": "扫描内网", + }) + if s == "" { + t.Fatal("expected non-empty input") + } + if !strings.Contains(s, "nmap") || !strings.Contains(s, "10.0.0.1") || !strings.Contains(s, "扫描内网") { + t.Fatalf("unexpected input: %s", s) + } +} diff --git a/internal/handler/hitl_cognition.go b/internal/handler/hitl_cognition.go new file mode 100644 index 00000000..6b24cb57 --- /dev/null +++ b/internal/handler/hitl_cognition.go @@ -0,0 +1,97 @@ +package handler + +import ( + "strings" +) + +type hitlCognitionState struct { + AssistantMessageID string + UserMessage string + Thinking string + ReasoningChain string + Planning string +} + +// GetHitlCognition 返回当前运行任务上缓存的本轮 HITL 上下文(不含会话历史)。 +func (m *AgentTaskManager) GetHitlCognition(conversationID string) hitlCognitionFields { + conversationID = strings.TrimSpace(conversationID) + if m == nil || conversationID == "" { + return hitlCognitionFields{} + } + m.mu.RLock() + defer m.mu.RUnlock() + t, ok := m.tasks[conversationID] + if !ok || t == nil || t.hitlCognition == nil { + return hitlCognitionFields{} + } + c := t.hitlCognition + return hitlCognitionFields{ + UserMessage: c.UserMessage, + Thinking: c.Thinking, + ReasoningChain: c.ReasoningChain, + Planning: c.Planning, + } +} + +// ResetHitlCognition 新任务开始时重置本轮 HITL 上下文。 +func (m *AgentTaskManager) ResetHitlCognition(conversationID, userMessage string) { + conversationID = strings.TrimSpace(conversationID) + if m == nil || conversationID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + t, ok := m.tasks[conversationID] + if !ok || t == nil { + return + } + t.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(userMessage)} +} + +// SetHitlAssistantMessageID 记录当前助手消息 ID,供 HITL 与 DB 回退对齐。 +func (m *AgentTaskManager) SetHitlAssistantMessageID(conversationID, assistantMessageID string) { + conversationID = strings.TrimSpace(conversationID) + assistantMessageID = strings.TrimSpace(assistantMessageID) + if m == nil || conversationID == "" || assistantMessageID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + t, ok := m.tasks[conversationID] + if !ok || t == nil { + return + } + if t.hitlCognition == nil { + t.hitlCognition = &hitlCognitionState{} + } + t.hitlCognition.AssistantMessageID = assistantMessageID +} + +// UpdateHitlCognitionSnapshot 从进行中的进度流快照更新 thinking / reasoning / planning。 +func (m *AgentTaskManager) UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoningChain, planning string) { + conversationID = strings.TrimSpace(conversationID) + if m == nil || conversationID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + t, ok := m.tasks[conversationID] + if !ok || t == nil { + return + } + if t.hitlCognition == nil { + t.hitlCognition = &hitlCognitionState{} + } + if id := strings.TrimSpace(assistantMessageID); id != "" { + t.hitlCognition.AssistantMessageID = id + } + if s := strings.TrimSpace(thinking); s != "" { + t.hitlCognition.Thinking = s + } + if s := strings.TrimSpace(reasoningChain); s != "" { + t.hitlCognition.ReasoningChain = s + } + if s := strings.TrimSpace(planning); s != "" { + t.hitlCognition.Planning = s + } +} diff --git a/internal/handler/hitl_context.go b/internal/handler/hitl_context.go new file mode 100644 index 00000000..e8d551ad --- /dev/null +++ b/internal/handler/hitl_context.go @@ -0,0 +1,102 @@ +package handler + +import ( + "strings" +) + +const ( + hitlPayloadUserMessage = "userMessage" + hitlPayloadThinking = "thinking" + hitlPayloadReasoningChain = "reasoningChain" + hitlPayloadPlanning = "planning" +) + +type hitlCognitionFields struct { + UserMessage string + Thinking string + ReasoningChain string + Planning string +} + +func (h *AgentHandler) enrichHitlApprovalPayload(conversationID, assistantMessageID string, payload map[string]interface{}) { + if h == nil || payload == nil { + return + } + cog := h.collectHitlCognition(conversationID, assistantMessageID) + if s := strings.TrimSpace(cog.UserMessage); s != "" { + payload[hitlPayloadUserMessage] = s + } + if s := strings.TrimSpace(cog.Thinking); s != "" { + payload[hitlPayloadThinking] = s + } + if s := strings.TrimSpace(cog.ReasoningChain); s != "" { + payload[hitlPayloadReasoningChain] = s + } + if s := strings.TrimSpace(cog.Planning); s != "" { + payload[hitlPayloadPlanning] = s + } +} + +func (h *AgentHandler) collectHitlCognition(conversationID, assistantMessageID string) hitlCognitionFields { + var out hitlCognitionFields + if h.tasks != nil { + out = h.tasks.GetHitlCognition(conversationID) + } + if strings.TrimSpace(out.UserMessage) == "" && h.db != nil { + if msg, err := h.db.GetTurnUserMessage(conversationID, assistantMessageID); err == nil { + out.UserMessage = msg + } + } + if h.db != nil && assistantMessageID != "" { + dbCog, err := h.db.GetAssistantCognitionTexts(assistantMessageID) + if err == nil { + if strings.TrimSpace(out.Thinking) == "" { + out.Thinking = dbCog.Thinking + } + if strings.TrimSpace(out.ReasoningChain) == "" { + out.ReasoningChain = dbCog.ReasoningChain + } + if strings.TrimSpace(out.Planning) == "" { + out.Planning = dbCog.Planning + } + } + } + return out +} + +func snapshotHitlCognitionFromStreams(thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) (thinking, reasoningChain, planning string) { + if len(thinkingStreams) > 0 { + var thinkingParts, reasoningParts []string + for _, tb := range thinkingStreams { + if tb == nil { + continue + } + content := strings.TrimSpace(tb.b.String()) + if content == "" { + continue + } + if tb.persistAs == "reasoning_chain" { + reasoningParts = append(reasoningParts, content) + } else { + thinkingParts = append(thinkingParts, content) + } + } + thinking = strings.Join(thinkingParts, "\n\n") + reasoningChain = strings.Join(reasoningParts, "\n\n") + } + if respPlan != nil { + planning = strings.TrimSpace(respPlan.b.String()) + } + return thinking, reasoningChain, planning +} + +func (h *AgentHandler) syncHitlCognitionFromProgress(conversationID, assistantMessageID string, thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) { + if h == nil || h.tasks == nil { + return + } + thinking, reasoning, planning := snapshotHitlCognitionFromStreams(thinkingStreams, respPlan) + if thinking == "" && reasoning == "" && planning == "" { + return + } + h.tasks.UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoning, planning) +} diff --git a/internal/handler/hitl_context_test.go b/internal/handler/hitl_context_test.go new file mode 100644 index 00000000..cdf3870c --- /dev/null +++ b/internal/handler/hitl_context_test.go @@ -0,0 +1,46 @@ +package handler + +import ( + "os" + "path/filepath" + "testing" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +func TestEnrichHitlApprovalPayload(t *testing.T) { + tmp := t.TempDir() + db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop()) + if err != nil { + t.Fatalf("db: %v", err) + } + defer os.RemoveAll(tmp) + + conv, err := db.CreateConversation("hitl ctx", database.ConversationCreateMeta{}) + if err != nil { + t.Fatalf("conv: %v", err) + } + if _, err := db.AddMessage(conv.ID, "user", "scan 10.0.0.1 please", nil); err != nil { + t.Fatalf("user msg: %v", err) + } + asst, err := db.AddMessage(conv.ID, "assistant", "", nil) + if err != nil { + t.Fatalf("asst msg: %v", err) + } + if err := db.AddProcessDetail(asst.ID, conv.ID, "thinking", "need port scan first", nil); err != nil { + t.Fatalf("detail: %v", err) + } + + h := &AgentHandler{db: db, tasks: NewAgentTaskManager()} + payload := map[string]interface{}{"toolName": "nmap", "arguments": "{}"} + h.enrichHitlApprovalPayload(conv.ID, asst.ID, payload) + + if got := payload["userMessage"]; got != "scan 10.0.0.1 please" { + t.Fatalf("userMessage=%v", got) + } + if got := payload["thinking"]; got != "need port scan first" { + t.Fatalf("thinking=%v", got) + } +} diff --git a/internal/handler/hitl_execution.go b/internal/handler/hitl_execution.go new file mode 100644 index 00000000..8d44b6d1 --- /dev/null +++ b/internal/handler/hitl_execution.go @@ -0,0 +1,132 @@ +package handler + +import ( + "encoding/json" + "strings" + "time" +) + +const hitlPayloadExecutionResult = "executionResult" + +type hitlExecutionResult struct { + Success bool `json:"success"` + Result string `json:"result,omitempty"` + ToolName string `json:"toolName,omitempty"` + ToolCallID string `json:"toolCallId,omitempty"` + RecordedAt time.Time `json:"recordedAt"` +} + +type hitlApprovedExecTrack struct { + InterruptID string + ConversationID string + ToolName string + ToolCallID string +} + +// TrackApprovedHitlExecution 审批通过后登记,待 tool_result 回写执行结果。 +func (m *HITLManager) TrackApprovedHitlExecution(interruptID, conversationID, toolName, toolCallID string) { + if m == nil { + return + } + interruptID = strings.TrimSpace(interruptID) + conversationID = strings.TrimSpace(conversationID) + if interruptID == "" || conversationID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if m.approvedExec == nil { + m.approvedExec = make(map[string][]hitlApprovedExecTrack) + } + m.approvedExec[conversationID] = append(m.approvedExec[conversationID], hitlApprovedExecTrack{ + InterruptID: interruptID, + ConversationID: conversationID, + ToolName: strings.TrimSpace(toolName), + ToolCallID: strings.TrimSpace(toolCallID), + }) +} + +func (m *HITLManager) popApprovedInterruptForTool(conversationID, toolCallID, toolName string) string { + if m == nil { + return "" + } + conversationID = strings.TrimSpace(conversationID) + toolCallID = strings.TrimSpace(toolCallID) + toolName = strings.TrimSpace(toolName) + m.mu.Lock() + defer m.mu.Unlock() + queue := m.approvedExec[conversationID] + if len(queue) == 0 { + return "" + } + idx := -1 + if toolCallID != "" { + for i, t := range queue { + if t.ToolCallID == toolCallID { + idx = i + break + } + } + } + if idx < 0 && toolName != "" { + for i, t := range queue { + if strings.EqualFold(t.ToolName, toolName) { + idx = i + break + } + } + } + if idx < 0 { + return "" + } + id := queue[idx].InterruptID + queue = append(queue[:idx], queue[idx+1:]...) + if len(queue) == 0 { + delete(m.approvedExec, conversationID) + } else { + m.approvedExec[conversationID] = queue + } + return id +} + +func mergeHitlPayloadExecutionResult(payloadJSON string, exec hitlExecutionResult) (string, error) { + root := make(map[string]interface{}) + if strings.TrimSpace(payloadJSON) != "" { + _ = json.Unmarshal([]byte(payloadJSON), &root) + } + if root == nil { + root = make(map[string]interface{}) + } + root[hitlPayloadExecutionResult] = exec + out, err := json.Marshal(root) + if err != nil { + return payloadJSON, err + } + return string(out), nil +} + +func (h *AgentHandler) recordHitlToolExecutionResult(conversationID, toolCallID, toolName string, success bool, result string) { + if h == nil || h.hitlManager == nil || h.db == nil { + return + } + interruptID := h.hitlManager.popApprovedInterruptForTool(conversationID, toolCallID, toolName) + if interruptID == "" { + return + } + var payloadJSON string + err := h.db.QueryRow(`SELECT payload FROM hitl_interrupts WHERE id = ?`, interruptID).Scan(&payloadJSON) + if err != nil { + return + } + merged, err := mergeHitlPayloadExecutionResult(payloadJSON, hitlExecutionResult{ + Success: success, + Result: strings.TrimSpace(result), + ToolName: strings.TrimSpace(toolName), + ToolCallID: strings.TrimSpace(toolCallID), + RecordedAt: time.Now(), + }) + if err != nil { + return + } + _, _ = h.db.Exec(`UPDATE hitl_interrupts SET payload = ? WHERE id = ?`, merged, interruptID) +} diff --git a/internal/handler/hitl_execution_test.go b/internal/handler/hitl_execution_test.go new file mode 100644 index 00000000..1c620366 --- /dev/null +++ b/internal/handler/hitl_execution_test.go @@ -0,0 +1,39 @@ +package handler + +import ( + "encoding/json" + "testing" +) + +func TestMergeHitlPayloadExecutionResult(t *testing.T) { + merged, err := mergeHitlPayloadExecutionResult(`{"userMessage":"hi","toolName":"nmap"}`, hitlExecutionResult{ + Success: true, + Result: "open ports: 80", + }) + if err != nil { + t.Fatal(err) + } + var root map[string]interface{} + if err := json.Unmarshal([]byte(merged), &root); err != nil { + t.Fatal(err) + } + if root["userMessage"] != "hi" { + t.Fatalf("userMessage lost: %v", root["userMessage"]) + } + exec, ok := root["executionResult"].(map[string]interface{}) + if !ok || exec["success"] != true { + t.Fatalf("executionResult missing: %v", root["executionResult"]) + } +} + +func TestPopApprovedInterruptForTool(t *testing.T) { + m := NewHITLManager(nil, nil) + m.TrackApprovedHitlExecution("hitl_a", "conv1", "nmap", "tc1") + m.TrackApprovedHitlExecution("hitl_b", "conv1", "exec", "") + if id := m.popApprovedInterruptForTool("conv1", "tc1", "nmap"); id != "hitl_a" { + t.Fatalf("tc1 match=%q", id) + } + if id := m.popApprovedInterruptForTool("conv1", "", "exec"); id != "hitl_b" { + t.Fatalf("tool name match=%q", id) + } +} diff --git a/internal/handler/hitl_logs.go b/internal/handler/hitl_logs.go new file mode 100644 index 00000000..246ecb98 --- /dev/null +++ b/internal/handler/hitl_logs.go @@ -0,0 +1,201 @@ +package handler + +import ( + "database/sql" + "errors" + "math" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +func normalizeHitlReviewer(v string) string { + switch strings.ToLower(strings.TrimSpace(v)) { + case "audit_agent", "agent", "ai": + return "audit_agent" + default: + return "human" + } +} + +func normalizeHitlDecidedBy(v string) string { + switch strings.ToLower(strings.TrimSpace(v)) { + case "audit_agent", "agent", "ai": + return "audit_agent" + case "system", "timeout": + return "system" + case "manual": + return "manual" + default: + return "human" + } +} + +func (m *HITLManager) migrateHitlSchemaColumns() { + _, _ = m.db.Exec(`ALTER TABLE hitl_interrupts ADD COLUMN decided_by TEXT NOT NULL DEFAULT 'human'`) + _, _ = m.db.Exec(`ALTER TABLE hitl_conversation_configs ADD COLUMN reviewer TEXT NOT NULL DEFAULT 'human'`) +} + +func hitlInterruptRowToMap( + id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string, + messageID sql.NullString, + decision, comment sql.NullString, + createdAt time.Time, + decidedAt sql.NullTime, +) map[string]interface{} { + msgID := "" + if messageID.Valid { + msgID = messageID.String + } + return map[string]interface{}{ + "id": id, + "conversationId": cid, + "messageId": msgID, + "mode": mode, + "toolName": toolName, + "toolCallId": toolCallID, + "payload": payload, + "status": rowStatus, + "decision": decision.String, + "comment": comment.String, + "decidedBy": decidedBy, + "createdAt": createdAt, + "decidedAt": func() interface{} { + if decidedAt.Valid { + return decidedAt.Time + } + return nil + }(), + } +} + +func (h *AgentHandler) buildHitlListQuery(logs bool) (string, []interface{}) { + q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE 1=1` + args := []interface{}{} + if logs { + q += " AND status != 'pending'" + } else { + q += " AND status = 'pending'" + } + return q, args +} + +func (h *AgentHandler) appendHitlListFilters(q string, args []interface{}, c *gin.Context) (string, []interface{}) { + conversationID := strings.TrimSpace(c.Query("conversationId")) + toolName := strings.TrimSpace(c.Query("toolName")) + decision := strings.TrimSpace(c.Query("decision")) + decidedBy := strings.TrimSpace(c.Query("decidedBy")) + status := strings.TrimSpace(c.Query("status")) + search := strings.TrimSpace(c.Query("q")) + + if conversationID != "" { + q += " AND conversation_id = ?" + args = append(args, conversationID) + } + if toolName != "" { + q += " AND tool_name LIKE ?" + args = append(args, "%"+toolName+"%") + } + if decision != "" && decision != "all" { + q += " AND decision = ?" + args = append(args, decision) + } + if decidedBy != "" && decidedBy != "all" { + q += " AND COALESCE(decided_by,'human') = ?" + args = append(args, normalizeHitlDecidedBy(decidedBy)) + } + if status != "" && status != "all" { + q += " AND status = ?" + args = append(args, status) + } + if search != "" { + like := "%" + search + "%" + q += " AND (id LIKE ? OR conversation_id LIKE ? OR tool_name LIKE ? OR payload LIKE ? OR COALESCE(decision_comment,'') LIKE ?)" + args = append(args, like, like, like, like, like) + } + return q, args +} + +func (h *AgentHandler) scanHitlInterruptRows(rows *sql.Rows) ([]map[string]interface{}, error) { + items := make([]map[string]interface{}, 0) + for rows.Next() { + var id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string + var messageID sql.NullString + var decision, comment sql.NullString + var createdAt time.Time + var decidedAt sql.NullTime + if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt); err != nil { + continue + } + items = append(items, hitlInterruptRowToMap(id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt)) + } + return items, nil +} + +func (h *AgentHandler) countHitlQuery(baseQ string, args []interface{}) (int, error) { + countQ := "SELECT COUNT(*) FROM (" + baseQ + ") AS hitl_cnt" + var total int + if err := h.db.QueryRow(countQ, args...).Scan(&total); err != nil { + return 0, err + } + return total, nil +} + +func (h *AgentHandler) ListHITLLogs(c *gin.Context) { + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + if page < 1 { + page = 1 + } + pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) + pageSize = int(math.Max(1, math.Min(float64(pageSize), 200))) + offset := (page - 1) * pageSize + + q, args := h.buildHitlListQuery(true) + q, args = h.appendHitlListFilters(q, args, c) + total, err := h.countHitlQuery(q, args) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + q += " ORDER BY COALESCE(decided_at, created_at) DESC LIMIT ? OFFSET ?" + args = append(args, pageSize, offset) + rows, err := h.db.Query(q, args...) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + defer rows.Close() + items, err := h.scanHitlInterruptRows(rows) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total}) +} + +func (h *AgentHandler) GetHITLLog(c *gin.Context) { + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) + return + } + q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE id = ?` + var rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string + var messageID sql.NullString + var decision, comment sql.NullString + var createdAt time.Time + var decidedAt sql.NullTime + err := h.db.QueryRow(q, id).Scan(&rowID, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt) + if errors.Is(err, sql.ErrNoRows) { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, hitlInterruptRowToMap(rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt)) +} diff --git a/internal/handler/task_manager.go b/internal/handler/task_manager.go index 4b19b40f..5869007a 100644 --- a/internal/handler/task_manager.go +++ b/internal/handler/task_manager.go @@ -43,6 +43,9 @@ type AgentTask struct { // activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果 activeEinoExecuteAbortNote string + // hitlCognition 本轮运行中供 HITL/审计 Agent 读取的上下文(用户原话 + 思考,不含会话历史) + hitlCognition *hitlCognitionState + cancel func(error) } @@ -354,6 +357,7 @@ func (m *AgentTaskManager) StartTask(conversationID, message string, cancel cont } m.tasks[conversationID] = task + task.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(message)} return task, nil }