mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-30 17:55:32 +02:00
Add files via upload
This commit is contained in:
@@ -77,6 +77,13 @@ type responsePlanAgg struct {
|
||||
b strings.Builder
|
||||
}
|
||||
|
||||
// thinkingBuf aggregates thinking_stream_* / reasoning_chain_stream_* before flush to process_details.
|
||||
type thinkingBuf struct {
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
persistAs string // "thinking" | "reasoning_chain"
|
||||
}
|
||||
|
||||
func normalizeProcessDetailText(s string) string {
|
||||
s = strings.ReplaceAll(s, "\r\n", "\n")
|
||||
s = strings.ReplaceAll(s, "\r", "\n")
|
||||
@@ -179,6 +186,8 @@ type AgentHandler struct {
|
||||
batchCronParser cron.Parser
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
hitlStrategySaver HitlAuditStrategySaver
|
||||
auditLLM *openai.Client
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
@@ -218,9 +227,10 @@ func (h *AgentHandler) cancelActiveMCPToolForConversation(conversationID string)
|
||||
}
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
// HitlToolWhitelistSaver 合并/设置 HITL 免审批工具到全局配置并落盘
|
||||
type HitlToolWhitelistSaver interface {
|
||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||
SetHitlToolWhitelist(tools []string) error
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
@@ -236,6 +246,11 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
bus := NewTaskEventBus()
|
||||
tm := NewAgentTaskManager()
|
||||
tm.SetTaskEventBus(bus)
|
||||
llmHTTP := &http.Client{Timeout: 2 * time.Minute}
|
||||
var llmCfg *config.OpenAIConfig
|
||||
if cfg != nil {
|
||||
llmCfg = &cfg.OpenAI
|
||||
}
|
||||
handler := &AgentHandler{
|
||||
agent: agent,
|
||||
db: db,
|
||||
@@ -246,6 +261,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
config: cfg,
|
||||
hitlManager: NewHITLManager(db, logger),
|
||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||
auditLLM: openai.NewClient(llmCfg, llmHTTP, logger),
|
||||
}
|
||||
tm.SetToolCanceler(handler.cancelActiveMCPToolForConversation)
|
||||
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
||||
@@ -320,6 +336,7 @@ func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientInten
|
||||
type HITLRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Reviewer string `json:"reviewer,omitempty"` // human | audit_agent
|
||||
SensitiveTools []string `json:"sensitiveTools,omitempty"`
|
||||
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
|
||||
}
|
||||
@@ -849,11 +866,6 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
|
||||
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
||||
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
||||
type thinkingBuf struct {
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
persistAs string // "thinking" | "reasoning_chain"
|
||||
}
|
||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
||||
@@ -866,6 +878,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
||||
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
||||
var respPlan responsePlanAgg
|
||||
if assistantMessageID != "" {
|
||||
h.tasks.SetHitlAssistantMessageID(conversationID, assistantMessageID)
|
||||
}
|
||||
syncHitlCognition := func() {
|
||||
h.syncHitlCognitionFromProgress(conversationID, assistantMessageID, thinkingStreams, &respPlan)
|
||||
}
|
||||
flushResponsePlan := func() {
|
||||
if assistantMessageID == "" {
|
||||
return
|
||||
@@ -885,6 +903,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning"))
|
||||
}
|
||||
syncHitlCognition()
|
||||
respPlan.meta = nil
|
||||
respPlan.b.Reset()
|
||||
}
|
||||
@@ -921,6 +940,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
flushedThinking[sid] = true
|
||||
}
|
||||
syncHitlCognition()
|
||||
}
|
||||
|
||||
return func(eventType, message string, data interface{}) {
|
||||
@@ -981,6 +1001,25 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
}
|
||||
|
||||
if eventType == "tool_result" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
toolCallID, _ := dataMap["toolCallId"].(string)
|
||||
success := true
|
||||
if v, ok := dataMap["success"].(bool); ok {
|
||||
success = v
|
||||
}
|
||||
resultText := ""
|
||||
if r, ok := dataMap["result"].(string); ok {
|
||||
resultText = r
|
||||
}
|
||||
if strings.TrimSpace(resultText) == "" {
|
||||
resultText = message
|
||||
}
|
||||
h.recordHitlToolExecutionResult(conversationID, toolCallID, toolName, success, resultText)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理知识检索日志记录
|
||||
if eventType == "tool_result" && h.knowledgeManager != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
@@ -1188,6 +1227,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
respPlan.meta[k] = v
|
||||
}
|
||||
}
|
||||
syncHitlCognition()
|
||||
return
|
||||
}
|
||||
if eventType == "response" {
|
||||
@@ -1257,6 +1297,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
}
|
||||
}
|
||||
syncHitlCognition()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1766,6 +1766,20 @@ func mergeHitlToolWhitelistSlice(existing, add []string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
// SetHitlToolWhitelist 将全局免审批工具白名单整表写入 config.yaml(替换,非合并)。
|
||||
func (h *ConfigHandler) SetHitlToolWhitelist(tools []string) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.config.Hitl.ToolWhitelist = mergeHitlToolWhitelistSlice(nil, tools)
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Info("HITL 全局工具白名单已写入配置文件",
|
||||
zap.Int("count", len(h.config.Hitl.ToolWhitelist)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。
|
||||
func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error {
|
||||
h.mu.Lock()
|
||||
@@ -1786,6 +1800,21 @@ func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) {
|
||||
hitlNode := ensureMap(root, "hitl")
|
||||
// flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数
|
||||
setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist)
|
||||
setStringInMap(hitlNode, "audit_agent_prompt", cfg.AuditAgentPrompt)
|
||||
setStringInMap(hitlNode, "audit_agent_prompt_review_edit", cfg.AuditAgentPromptReviewEdit)
|
||||
}
|
||||
|
||||
// UpdateHitlAuditAgentStrategy 更新审批/审查编辑两套审计 Agent 提示词并写入 config.yaml。
|
||||
func (h *ConfigHandler) UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.config.Hitl.AuditAgentPrompt = strings.TrimSpace(approvalPrompt)
|
||||
h.config.Hitl.AuditAgentPromptReviewEdit = strings.TrimSpace(reviewEditPrompt)
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Info("HITL 审计 Agent 提示词已写入配置文件")
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
|
||||
+114
-60
@@ -23,6 +23,7 @@ import (
|
||||
type hitlRuntimeConfig struct {
|
||||
Enabled bool
|
||||
Mode string
|
||||
Reviewer string
|
||||
SensitiveTools map[string]struct{}
|
||||
Timeout time.Duration
|
||||
}
|
||||
@@ -49,6 +50,8 @@ type HITLManager struct {
|
||||
mu sync.RWMutex
|
||||
runtime map[string]hitlRuntimeConfig
|
||||
pending map[string]*pendingInterrupt
|
||||
// approvedExec 审批通过、待回写 tool_result 的队列(按会话 FIFO)
|
||||
approvedExec map[string][]hitlApprovedExecTrack
|
||||
}
|
||||
|
||||
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
|
||||
@@ -90,6 +93,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.migrateHitlSchemaColumns()
|
||||
|
||||
// On startup, cancel all orphaned pending interrupts from previous process.
|
||||
// Their in-memory channels are gone, so they can never be resolved.
|
||||
@@ -141,6 +145,7 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque
|
||||
m.runtime[conversationID] = hitlRuntimeConfig{
|
||||
Enabled: true,
|
||||
Mode: normalizeHitlMode(req.Mode),
|
||||
Reviewer: normalizeHitlReviewer(req.Reviewer),
|
||||
SensitiveTools: tools,
|
||||
Timeout: timeout,
|
||||
}
|
||||
@@ -362,22 +367,22 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
|
||||
timeout = 0
|
||||
}
|
||||
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
|
||||
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
(conversation_id, enabled, mode, reviewer, sensitive_tools, timeout_seconds, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(conversation_id) DO UPDATE SET
|
||||
enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
||||
conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now())
|
||||
enabled=excluded.enabled, mode=excluded.mode, reviewer=excluded.reviewer, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
||||
conversationID, boolToInt(req.Enabled), mode, normalizeHitlReviewer(req.Reviewer), string(tools), timeout, time.Now())
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
|
||||
var enabledInt int
|
||||
var mode, toolsJSON string
|
||||
var mode, reviewer, toolsJSON string
|
||||
var timeout int
|
||||
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
|
||||
err := m.db.QueryRow(`SELECT enabled, mode, COALESCE(reviewer,'human'), sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||
Scan(&enabledInt, &mode, &reviewer, &toolsJSON, &timeout)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
||||
return &HITLRequest{Enabled: false, Mode: "off", Reviewer: "human", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -390,6 +395,7 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques
|
||||
return &HITLRequest{
|
||||
Enabled: enabledInt == 1,
|
||||
Mode: mode,
|
||||
Reviewer: normalizeHitlReviewer(reviewer),
|
||||
SensitiveTools: tools,
|
||||
TimeoutSeconds: timeout,
|
||||
}, nil
|
||||
@@ -413,15 +419,15 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
|
||||
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
|
||||
d.EditedArguments = nil
|
||||
}
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='human' WHERE id=?`,
|
||||
d.Decision, d.Comment, time.Now(), p.InterruptID)
|
||||
return d, nil
|
||||
case <-timeoutCh:
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=?, decided_by='system' WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
|
||||
case <-ctx.Done():
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`,
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=?, decided_by='system' WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
|
||||
}
|
||||
@@ -445,12 +451,57 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
|
||||
if !need {
|
||||
return nil, nil
|
||||
}
|
||||
h.enrichHitlApprovalPayload(conversationID, assistantMessageID, payload)
|
||||
payloadRaw, _ := json.Marshal(payload)
|
||||
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
|
||||
if err != nil {
|
||||
h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.Reviewer == "audit_agent" {
|
||||
ad := h.auditAgentReview(runCtx, cfg.Mode, toolName, payload)
|
||||
now := time.Now()
|
||||
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='audit_agent' WHERE id=?`,
|
||||
ad.Decision, ad.Comment, now, p.InterruptID)
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_audit_agent", "审计 Agent 已裁决", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"mode": cfg.Mode,
|
||||
"decision": ad.Decision,
|
||||
"comment": ad.Comment,
|
||||
"editedArgs": ad.EditedArguments,
|
||||
"decidedBy": "audit_agent",
|
||||
})
|
||||
}
|
||||
if ad.Decision == "reject" {
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_rejected", "审计 Agent 拒绝本次工具调用", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": ad.Comment,
|
||||
"decidedBy": "audit_agent",
|
||||
})
|
||||
}
|
||||
return &ad, nil
|
||||
}
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_resumed", "审计 Agent 已通过,继续执行", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": ad.Comment,
|
||||
"editedArgs": ad.EditedArguments,
|
||||
"decidedBy": "audit_agent",
|
||||
})
|
||||
}
|
||||
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
|
||||
return &ad, nil
|
||||
}
|
||||
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
@@ -498,6 +549,7 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
|
||||
"editedArgs": d.EditedArguments,
|
||||
})
|
||||
}
|
||||
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
@@ -527,11 +579,6 @@ func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun cont
|
||||
}
|
||||
|
||||
func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
if status == "" {
|
||||
status = "pending"
|
||||
}
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
@@ -539,15 +586,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||
offset := (page - 1) * pageSize
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
q += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if status != "all" {
|
||||
q += " AND status = ?"
|
||||
args = append(args, status)
|
||||
q, args := h.buildHitlListQuery(false)
|
||||
q, args = h.appendHitlListFilters(q, args, c)
|
||||
total, err := h.countHitlQuery(q, args)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, pageSize, offset)
|
||||
@@ -557,41 +601,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]map[string]interface{}, 0)
|
||||
for rows.Next() {
|
||||
var id, cid, mode, toolName, toolCallID, payload, rowStatus string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
msgID := ""
|
||||
if messageID.Valid {
|
||||
msgID = messageID.String
|
||||
}
|
||||
items = append(items, map[string]interface{}{
|
||||
"id": id,
|
||||
"conversationId": cid,
|
||||
"messageId": msgID,
|
||||
"mode": mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
"status": rowStatus,
|
||||
"decision": decision.String,
|
||||
"comment": comment.String,
|
||||
"createdAt": createdAt,
|
||||
"decidedAt": func() interface{} {
|
||||
if decidedAt.Valid {
|
||||
return decidedAt.Time
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
})
|
||||
items, err := h.scanHitlInterruptRows(rows)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize})
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total})
|
||||
}
|
||||
|
||||
type hitlDecisionReq struct {
|
||||
@@ -636,7 +651,7 @@ func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
|
||||
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP
|
||||
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP, decided_by='human'
|
||||
WHERE id=? AND status='pending'`, req.InterruptID)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
@@ -732,6 +747,7 @@ func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Mode = normalizeHitlMode(req.Mode)
|
||||
req.Reviewer = normalizeHitlReviewer(req.Reviewer)
|
||||
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -753,6 +769,44 @@ type mergeHitlGlobalWhitelistReq struct {
|
||||
SensitiveTools []string `json:"sensitiveTools"`
|
||||
}
|
||||
|
||||
type setHitlGlobalWhitelistReq struct {
|
||||
ToolWhitelist []string `json:"toolWhitelist"`
|
||||
}
|
||||
|
||||
// GetHITLGlobalToolWhitelist 返回 config.yaml 中的全局免审批工具白名单。
|
||||
func (h *AgentHandler) GetHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
})
|
||||
}
|
||||
|
||||
// SetHITLGlobalToolWhitelist 整表替换 config.yaml 中的全局免审批工具白名单。
|
||||
func (h *AgentHandler) SetHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
if h.hitlWhitelistSaver == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
|
||||
return
|
||||
}
|
||||
var req setHitlGlobalWhitelistReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.hitlWhitelistSaver.SetHitlToolWhitelist(req.ToolWhitelist); err != nil {
|
||||
h.logger.Warn("写入 HITL 工具白名单到 config.yaml 失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "tool_whitelist_update", "HITL 全局白名单更新", "hitl_config", "tool_whitelist", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalWhitelistMerged": false,
|
||||
})
|
||||
}
|
||||
|
||||
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
|
||||
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
if h.hitlWhitelistSaver == nil {
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// auditAgentReview 在 reviewer=audit_agent 时由 LLM 代行审批。
|
||||
// 白名单工具在 shouldInterrupt 阶段已跳过,到达此处的一律需要裁决。
|
||||
func (h *AgentHandler) auditAgentReview(ctx context.Context, hitlMode, toolName string, payload map[string]interface{}) hitlDecision {
|
||||
if h == nil {
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: handler unavailable"}
|
||||
}
|
||||
mode := normalizeHitlMode(hitlMode)
|
||||
prompt := config.DefaultHitlAuditAgentPrompt()
|
||||
if h.config != nil {
|
||||
prompt = h.config.Hitl.EffectiveAuditAgentPromptForMode(mode)
|
||||
}
|
||||
if h.auditLLM == nil {
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 未配置"}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
callCtx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
userContent := buildAuditAgentReviewInput(mode, toolName, payload)
|
||||
requestBody := map[string]interface{}{
|
||||
"model": h.auditLLMModel(),
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": userContent},
|
||||
},
|
||||
"temperature": 0.1,
|
||||
"max_completion_tokens": 1024,
|
||||
// 审计裁决需要结构化 JSON;关闭 thinking 避免 Qwen 等把正文放进 reasoning_content 导致解析失败。
|
||||
"thinking": map[string]interface{}{"type": "disabled"},
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := h.auditLLM.ChatCompletion(callCtx, requestBody, &apiResponse); err != nil {
|
||||
h.logger.Warn("审计 Agent LLM 调用失败", zap.Error(err), zap.String("tool", toolName))
|
||||
return hitlDecision{
|
||||
Decision: "reject",
|
||||
Comment: "audit agent: LLM 调用失败,保守拒绝",
|
||||
}
|
||||
}
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 无有效响应,保守拒绝"}
|
||||
}
|
||||
msg := apiResponse.Choices[0].Message
|
||||
raw := strings.TrimSpace(msg.Content)
|
||||
if raw == "" {
|
||||
raw = strings.TrimSpace(msg.ReasoningContent)
|
||||
}
|
||||
dec, err := parseAuditAgentLLMContent(raw)
|
||||
if err != nil {
|
||||
snippet := raw
|
||||
if len(snippet) > 240 {
|
||||
snippet = snippet[:240] + "..."
|
||||
}
|
||||
h.logger.Warn("审计 Agent 响应解析失败",
|
||||
zap.Error(err),
|
||||
zap.String("tool", toolName),
|
||||
zap.String("mode", mode),
|
||||
zap.String("snippet", snippet),
|
||||
)
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: 响应无法解析,保守拒绝"}
|
||||
}
|
||||
if mode != "review_edit" && len(dec.EditedArguments) > 0 {
|
||||
h.logger.Warn("审计 Agent 在审批模式下返回 editedArguments,已忽略",
|
||||
zap.String("tool", toolName),
|
||||
)
|
||||
dec.EditedArguments = nil
|
||||
}
|
||||
if dec.Comment == "" {
|
||||
dec.Comment = "audit agent: " + dec.Decision
|
||||
} else if !strings.HasPrefix(strings.ToLower(dec.Comment), "audit agent") {
|
||||
dec.Comment = "audit agent: " + dec.Comment
|
||||
}
|
||||
return dec
|
||||
}
|
||||
|
||||
func (h *AgentHandler) auditLLMModel() string {
|
||||
if h.config != nil && strings.TrimSpace(h.config.OpenAI.Model) != "" {
|
||||
return strings.TrimSpace(h.config.OpenAI.Model)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildAuditAgentReviewInput(hitlMode, toolName string, payload map[string]interface{}) string {
|
||||
review := map[string]interface{}{
|
||||
"hitlMode": normalizeHitlMode(hitlMode),
|
||||
"toolName": strings.TrimSpace(toolName),
|
||||
}
|
||||
if payload != nil {
|
||||
for _, k := range []string{"arguments", "argumentsObj", "command", hitlPayloadUserMessage, hitlPayloadThinking, hitlPayloadReasoningChain, hitlPayloadPlanning} {
|
||||
if v, ok := payload[k]; ok && v != nil && fmt.Sprint(v) != "" {
|
||||
review[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.MarshalIndent(review, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"hitlMode":%q,"toolName":%q}`, normalizeHitlMode(hitlMode), toolName)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func parseAuditAgentLLMContent(content string) (hitlDecision, error) {
|
||||
s := strings.TrimSpace(content)
|
||||
if s == "" {
|
||||
return hitlDecision{}, errors.New("empty content")
|
||||
}
|
||||
for _, candidate := range auditAgentJSONCandidates(s) {
|
||||
dec, comment, editedArgs, err := parseAuditAgentDecisionObject(candidate)
|
||||
if err == nil {
|
||||
return hitlDecision{
|
||||
Decision: dec,
|
||||
Comment: comment,
|
||||
EditedArguments: editedArgs,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return hitlDecision{}, fmt.Errorf("no valid decision json in response")
|
||||
}
|
||||
|
||||
func auditAgentJSONCandidates(s string) []string {
|
||||
out := make([]string, 0, 4)
|
||||
seen := make(map[string]struct{})
|
||||
add := func(c string) {
|
||||
c = strings.TrimSpace(c)
|
||||
if c == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
return
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
out = append(out, c)
|
||||
}
|
||||
add(s)
|
||||
add(stripMarkdownCodeFence(s))
|
||||
if obj := extractFirstJSONObject(s); obj != "" {
|
||||
add(obj)
|
||||
}
|
||||
if obj := extractFirstJSONObject(stripMarkdownCodeFence(s)); obj != "" {
|
||||
add(obj)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stripMarkdownCodeFence(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
for _, fence := range []string{"```json", "```JSON", "```"} {
|
||||
if strings.HasPrefix(s, fence) {
|
||||
s = strings.TrimPrefix(s, fence)
|
||||
}
|
||||
}
|
||||
s = strings.TrimSuffix(s, "```")
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func extractFirstJSONObject(s string) string {
|
||||
start := strings.Index(s, "{")
|
||||
if start < 0 {
|
||||
return ""
|
||||
}
|
||||
depth := 0
|
||||
inStr := false
|
||||
esc := false
|
||||
for i := start; i < len(s); i++ {
|
||||
ch := s[i]
|
||||
if inStr {
|
||||
if esc {
|
||||
esc = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
esc = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inStr = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch ch {
|
||||
case '"':
|
||||
inStr = true
|
||||
case '{':
|
||||
depth++
|
||||
case '}':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return s[start : i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseAuditAgentDecisionObject(jsonText string) (decision, comment string, editedArgs map[string]interface{}, err error) {
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonText), &parsed); err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
rawDecision := auditAgentPickString(parsed, "decision", "Decision", "result", "action", "verdict", "决策", "决定")
|
||||
decision = normalizeAuditAgentDecision(rawDecision)
|
||||
if decision == "" {
|
||||
return "", "", nil, fmt.Errorf("missing decision")
|
||||
}
|
||||
comment = auditAgentPickString(parsed, "comment", "Comment", "reason", "message", "rationale", "备注", "理由", "说明")
|
||||
editedArgs = auditAgentPickObject(parsed, "editedArguments", "edited_arguments", "editedArgs")
|
||||
return decision, strings.TrimSpace(comment), editedArgs, nil
|
||||
}
|
||||
|
||||
func auditAgentPickString(m map[string]interface{}, keys ...string) string {
|
||||
for _, k := range keys {
|
||||
if v, ok := m[k]; ok && v != nil {
|
||||
s := strings.TrimSpace(fmt.Sprint(v))
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func auditAgentPickObject(m map[string]interface{}, keys ...string) map[string]interface{} {
|
||||
for _, k := range keys {
|
||||
v, ok := m[k]
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case map[string]interface{}:
|
||||
if len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" || s == "{}" {
|
||||
continue
|
||||
}
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(s), &obj); err == nil && len(obj) > 0 {
|
||||
return obj
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeAuditAgentDecision(v string) string {
|
||||
d := strings.ToLower(strings.TrimSpace(v))
|
||||
switch d {
|
||||
case "approve", "approved", "pass", "passed", "allow", "allowed", "yes", "ok", "accept", "accepted":
|
||||
return "approve"
|
||||
case "reject", "rejected", "deny", "denied", "no", "block", "blocked", "refuse", "refused":
|
||||
return "reject"
|
||||
}
|
||||
switch strings.TrimSpace(v) {
|
||||
case "通过", "批准", "允许", "同意", "放行":
|
||||
return "approve"
|
||||
case "拒绝", "驳回", "禁止", "否决":
|
||||
return "reject"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type hitlAuditStrategyReq struct {
|
||||
AuditAgentPrompt string `json:"auditAgentPrompt"`
|
||||
AuditAgentPromptReviewEdit string `json:"auditAgentPromptReviewEdit"`
|
||||
}
|
||||
|
||||
func (h *AgentHandler) GetHITLAuditStrategy(c *gin.Context) {
|
||||
approvalPrompt := config.DefaultHitlAuditAgentPrompt()
|
||||
reviewEditPrompt := config.DefaultHitlAuditAgentPromptReviewEdit()
|
||||
approvalCustom := false
|
||||
reviewEditCustom := false
|
||||
if h.config != nil {
|
||||
approvalPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("approval")
|
||||
reviewEditPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("review_edit")
|
||||
approvalCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPrompt) != ""
|
||||
reviewEditCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPromptReviewEdit) != ""
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"auditAgentPrompt": approvalPrompt,
|
||||
"auditAgentPromptCustom": approvalCustom,
|
||||
"auditAgentPromptReviewEdit": reviewEditPrompt,
|
||||
"auditAgentPromptReviewEditCustom": reviewEditCustom,
|
||||
"defaultAuditAgentPrompt": config.DefaultHitlAuditAgentPrompt(),
|
||||
"defaultAuditAgentPromptReviewEdit": config.DefaultHitlAuditAgentPromptReviewEdit(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) UpdateHITLAuditStrategy(c *gin.Context) {
|
||||
if h.hitlStrategySaver == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 策略持久化不可用"})
|
||||
return
|
||||
}
|
||||
var req hitlAuditStrategyReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
approvalPrompt := strings.TrimSpace(req.AuditAgentPrompt)
|
||||
reviewEditPrompt := strings.TrimSpace(req.AuditAgentPromptReviewEdit)
|
||||
if err := h.hitlStrategySaver.UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt); err != nil {
|
||||
h.logger.Warn("保存审计 Agent 提示词失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "audit_strategy_update", "HITL 审计策略更新", "hitl_config", "audit_agent_prompt", nil)
|
||||
}
|
||||
if h.config != nil {
|
||||
h.config.Hitl.AuditAgentPrompt = approvalPrompt
|
||||
h.config.Hitl.AuditAgentPromptReviewEdit = reviewEditPrompt
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"auditAgentPrompt": config.HitlConfig{AuditAgentPrompt: approvalPrompt}.EffectiveAuditAgentPromptForMode("approval"),
|
||||
"auditAgentPromptCustom": approvalPrompt != "",
|
||||
"auditAgentPromptReviewEdit": config.HitlConfig{AuditAgentPromptReviewEdit: reviewEditPrompt}.EffectiveAuditAgentPromptForMode("review_edit"),
|
||||
"auditAgentPromptReviewEditCustom": reviewEditPrompt != "",
|
||||
})
|
||||
}
|
||||
|
||||
// HitlAuditStrategySaver 持久化审计 Agent 提示词到 config.yaml。
|
||||
type HitlAuditStrategySaver interface {
|
||||
UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error
|
||||
}
|
||||
|
||||
// SetHitlAuditStrategySaver 设置审计策略落盘。
|
||||
func (h *AgentHandler) SetHitlAuditStrategySaver(s HitlAuditStrategySaver) {
|
||||
h.hitlStrategySaver = s
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseAuditAgentLLMContentApprove(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"与任务一致"}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" || d.Comment != "与任务一致" {
|
||||
t.Fatalf("unexpected %+v", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentReject(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent("```json\n{\"decision\":\"reject\",\"comment\":\"风险过高\"}\n```")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "reject" {
|
||||
t.Fatalf("expected reject, got %s", d.Decision)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentInvalid(t *testing.T) {
|
||||
_, err := parseAuditAgentLLMContent(`{"decision":"maybe"}`)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid decision")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentProseWrapped(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent("好的,裁决如下:\n```json\n{\"decision\":\"approve\",\"comment\":\"只读 ls\"}\n```\n以上。")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" {
|
||||
t.Fatalf("expected approve, got %s", d.Decision)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentChineseDecision(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent(`{"decision":"通过","comment":"风险低"}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" {
|
||||
t.Fatalf("expected approve, got %s", d.Decision)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentWithEditedArguments(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"收窄路径","editedArguments":{"path":"/safe"}}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" {
|
||||
t.Fatalf("expected approve, got %s", d.Decision)
|
||||
}
|
||||
if d.EditedArguments == nil || d.EditedArguments["path"] != "/safe" {
|
||||
t.Fatalf("unexpected edited args: %+v", d.EditedArguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuditAgentReviewInputIncludesMode(t *testing.T) {
|
||||
s := buildAuditAgentReviewInput("review_edit", "execute", map[string]interface{}{
|
||||
"arguments": `{"command":"pwd"}`,
|
||||
})
|
||||
if !strings.Contains(s, "review_edit") || !strings.Contains(s, "execute") {
|
||||
t.Fatalf("unexpected input: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuditAgentReviewInput(t *testing.T) {
|
||||
s := buildAuditAgentReviewInput("approval", "nmap", map[string]interface{}{
|
||||
"arguments": `{"target":"10.0.0.1"}`,
|
||||
"userMessage": "扫描内网",
|
||||
})
|
||||
if s == "" {
|
||||
t.Fatal("expected non-empty input")
|
||||
}
|
||||
if !strings.Contains(s, "nmap") || !strings.Contains(s, "10.0.0.1") || !strings.Contains(s, "扫描内网") {
|
||||
t.Fatalf("unexpected input: %s", s)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type hitlCognitionState struct {
|
||||
AssistantMessageID string
|
||||
UserMessage string
|
||||
Thinking string
|
||||
ReasoningChain string
|
||||
Planning string
|
||||
}
|
||||
|
||||
// GetHitlCognition 返回当前运行任务上缓存的本轮 HITL 上下文(不含会话历史)。
|
||||
func (m *AgentTaskManager) GetHitlCognition(conversationID string) hitlCognitionFields {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if m == nil || conversationID == "" {
|
||||
return hitlCognitionFields{}
|
||||
}
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil || t.hitlCognition == nil {
|
||||
return hitlCognitionFields{}
|
||||
}
|
||||
c := t.hitlCognition
|
||||
return hitlCognitionFields{
|
||||
UserMessage: c.UserMessage,
|
||||
Thinking: c.Thinking,
|
||||
ReasoningChain: c.ReasoningChain,
|
||||
Planning: c.Planning,
|
||||
}
|
||||
}
|
||||
|
||||
// ResetHitlCognition 新任务开始时重置本轮 HITL 上下文。
|
||||
func (m *AgentTaskManager) ResetHitlCognition(conversationID, userMessage string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if m == nil || conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil {
|
||||
return
|
||||
}
|
||||
t.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(userMessage)}
|
||||
}
|
||||
|
||||
// SetHitlAssistantMessageID 记录当前助手消息 ID,供 HITL 与 DB 回退对齐。
|
||||
func (m *AgentTaskManager) SetHitlAssistantMessageID(conversationID, assistantMessageID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
assistantMessageID = strings.TrimSpace(assistantMessageID)
|
||||
if m == nil || conversationID == "" || assistantMessageID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil {
|
||||
return
|
||||
}
|
||||
if t.hitlCognition == nil {
|
||||
t.hitlCognition = &hitlCognitionState{}
|
||||
}
|
||||
t.hitlCognition.AssistantMessageID = assistantMessageID
|
||||
}
|
||||
|
||||
// UpdateHitlCognitionSnapshot 从进行中的进度流快照更新 thinking / reasoning / planning。
|
||||
func (m *AgentTaskManager) UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoningChain, planning string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if m == nil || conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil {
|
||||
return
|
||||
}
|
||||
if t.hitlCognition == nil {
|
||||
t.hitlCognition = &hitlCognitionState{}
|
||||
}
|
||||
if id := strings.TrimSpace(assistantMessageID); id != "" {
|
||||
t.hitlCognition.AssistantMessageID = id
|
||||
}
|
||||
if s := strings.TrimSpace(thinking); s != "" {
|
||||
t.hitlCognition.Thinking = s
|
||||
}
|
||||
if s := strings.TrimSpace(reasoningChain); s != "" {
|
||||
t.hitlCognition.ReasoningChain = s
|
||||
}
|
||||
if s := strings.TrimSpace(planning); s != "" {
|
||||
t.hitlCognition.Planning = s
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
hitlPayloadUserMessage = "userMessage"
|
||||
hitlPayloadThinking = "thinking"
|
||||
hitlPayloadReasoningChain = "reasoningChain"
|
||||
hitlPayloadPlanning = "planning"
|
||||
)
|
||||
|
||||
type hitlCognitionFields struct {
|
||||
UserMessage string
|
||||
Thinking string
|
||||
ReasoningChain string
|
||||
Planning string
|
||||
}
|
||||
|
||||
func (h *AgentHandler) enrichHitlApprovalPayload(conversationID, assistantMessageID string, payload map[string]interface{}) {
|
||||
if h == nil || payload == nil {
|
||||
return
|
||||
}
|
||||
cog := h.collectHitlCognition(conversationID, assistantMessageID)
|
||||
if s := strings.TrimSpace(cog.UserMessage); s != "" {
|
||||
payload[hitlPayloadUserMessage] = s
|
||||
}
|
||||
if s := strings.TrimSpace(cog.Thinking); s != "" {
|
||||
payload[hitlPayloadThinking] = s
|
||||
}
|
||||
if s := strings.TrimSpace(cog.ReasoningChain); s != "" {
|
||||
payload[hitlPayloadReasoningChain] = s
|
||||
}
|
||||
if s := strings.TrimSpace(cog.Planning); s != "" {
|
||||
payload[hitlPayloadPlanning] = s
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) collectHitlCognition(conversationID, assistantMessageID string) hitlCognitionFields {
|
||||
var out hitlCognitionFields
|
||||
if h.tasks != nil {
|
||||
out = h.tasks.GetHitlCognition(conversationID)
|
||||
}
|
||||
if strings.TrimSpace(out.UserMessage) == "" && h.db != nil {
|
||||
if msg, err := h.db.GetTurnUserMessage(conversationID, assistantMessageID); err == nil {
|
||||
out.UserMessage = msg
|
||||
}
|
||||
}
|
||||
if h.db != nil && assistantMessageID != "" {
|
||||
dbCog, err := h.db.GetAssistantCognitionTexts(assistantMessageID)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(out.Thinking) == "" {
|
||||
out.Thinking = dbCog.Thinking
|
||||
}
|
||||
if strings.TrimSpace(out.ReasoningChain) == "" {
|
||||
out.ReasoningChain = dbCog.ReasoningChain
|
||||
}
|
||||
if strings.TrimSpace(out.Planning) == "" {
|
||||
out.Planning = dbCog.Planning
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func snapshotHitlCognitionFromStreams(thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) (thinking, reasoningChain, planning string) {
|
||||
if len(thinkingStreams) > 0 {
|
||||
var thinkingParts, reasoningParts []string
|
||||
for _, tb := range thinkingStreams {
|
||||
if tb == nil {
|
||||
continue
|
||||
}
|
||||
content := strings.TrimSpace(tb.b.String())
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
if tb.persistAs == "reasoning_chain" {
|
||||
reasoningParts = append(reasoningParts, content)
|
||||
} else {
|
||||
thinkingParts = append(thinkingParts, content)
|
||||
}
|
||||
}
|
||||
thinking = strings.Join(thinkingParts, "\n\n")
|
||||
reasoningChain = strings.Join(reasoningParts, "\n\n")
|
||||
}
|
||||
if respPlan != nil {
|
||||
planning = strings.TrimSpace(respPlan.b.String())
|
||||
}
|
||||
return thinking, reasoningChain, planning
|
||||
}
|
||||
|
||||
func (h *AgentHandler) syncHitlCognitionFromProgress(conversationID, assistantMessageID string, thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) {
|
||||
if h == nil || h.tasks == nil {
|
||||
return
|
||||
}
|
||||
thinking, reasoning, planning := snapshotHitlCognitionFromStreams(thinkingStreams, respPlan)
|
||||
if thinking == "" && reasoning == "" && planning == "" {
|
||||
return
|
||||
}
|
||||
h.tasks.UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoning, planning)
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestEnrichHitlApprovalPayload(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("db: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmp)
|
||||
|
||||
conv, err := db.CreateConversation("hitl ctx", database.ConversationCreateMeta{})
|
||||
if err != nil {
|
||||
t.Fatalf("conv: %v", err)
|
||||
}
|
||||
if _, err := db.AddMessage(conv.ID, "user", "scan 10.0.0.1 please", nil); err != nil {
|
||||
t.Fatalf("user msg: %v", err)
|
||||
}
|
||||
asst, err := db.AddMessage(conv.ID, "assistant", "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("asst msg: %v", err)
|
||||
}
|
||||
if err := db.AddProcessDetail(asst.ID, conv.ID, "thinking", "need port scan first", nil); err != nil {
|
||||
t.Fatalf("detail: %v", err)
|
||||
}
|
||||
|
||||
h := &AgentHandler{db: db, tasks: NewAgentTaskManager()}
|
||||
payload := map[string]interface{}{"toolName": "nmap", "arguments": "{}"}
|
||||
h.enrichHitlApprovalPayload(conv.ID, asst.ID, payload)
|
||||
|
||||
if got := payload["userMessage"]; got != "scan 10.0.0.1 please" {
|
||||
t.Fatalf("userMessage=%v", got)
|
||||
}
|
||||
if got := payload["thinking"]; got != "need port scan first" {
|
||||
t.Fatalf("thinking=%v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const hitlPayloadExecutionResult = "executionResult"
|
||||
|
||||
type hitlExecutionResult struct {
|
||||
Success bool `json:"success"`
|
||||
Result string `json:"result,omitempty"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
ToolCallID string `json:"toolCallId,omitempty"`
|
||||
RecordedAt time.Time `json:"recordedAt"`
|
||||
}
|
||||
|
||||
type hitlApprovedExecTrack struct {
|
||||
InterruptID string
|
||||
ConversationID string
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
}
|
||||
|
||||
// TrackApprovedHitlExecution 审批通过后登记,待 tool_result 回写执行结果。
|
||||
func (m *HITLManager) TrackApprovedHitlExecution(interruptID, conversationID, toolName, toolCallID string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
interruptID = strings.TrimSpace(interruptID)
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if interruptID == "" || conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.approvedExec == nil {
|
||||
m.approvedExec = make(map[string][]hitlApprovedExecTrack)
|
||||
}
|
||||
m.approvedExec[conversationID] = append(m.approvedExec[conversationID], hitlApprovedExecTrack{
|
||||
InterruptID: interruptID,
|
||||
ConversationID: conversationID,
|
||||
ToolName: strings.TrimSpace(toolName),
|
||||
ToolCallID: strings.TrimSpace(toolCallID),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *HITLManager) popApprovedInterruptForTool(conversationID, toolCallID, toolName string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
toolCallID = strings.TrimSpace(toolCallID)
|
||||
toolName = strings.TrimSpace(toolName)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
queue := m.approvedExec[conversationID]
|
||||
if len(queue) == 0 {
|
||||
return ""
|
||||
}
|
||||
idx := -1
|
||||
if toolCallID != "" {
|
||||
for i, t := range queue {
|
||||
if t.ToolCallID == toolCallID {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if idx < 0 && toolName != "" {
|
||||
for i, t := range queue {
|
||||
if strings.EqualFold(t.ToolName, toolName) {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
id := queue[idx].InterruptID
|
||||
queue = append(queue[:idx], queue[idx+1:]...)
|
||||
if len(queue) == 0 {
|
||||
delete(m.approvedExec, conversationID)
|
||||
} else {
|
||||
m.approvedExec[conversationID] = queue
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func mergeHitlPayloadExecutionResult(payloadJSON string, exec hitlExecutionResult) (string, error) {
|
||||
root := make(map[string]interface{})
|
||||
if strings.TrimSpace(payloadJSON) != "" {
|
||||
_ = json.Unmarshal([]byte(payloadJSON), &root)
|
||||
}
|
||||
if root == nil {
|
||||
root = make(map[string]interface{})
|
||||
}
|
||||
root[hitlPayloadExecutionResult] = exec
|
||||
out, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return payloadJSON, err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) recordHitlToolExecutionResult(conversationID, toolCallID, toolName string, success bool, result string) {
|
||||
if h == nil || h.hitlManager == nil || h.db == nil {
|
||||
return
|
||||
}
|
||||
interruptID := h.hitlManager.popApprovedInterruptForTool(conversationID, toolCallID, toolName)
|
||||
if interruptID == "" {
|
||||
return
|
||||
}
|
||||
var payloadJSON string
|
||||
err := h.db.QueryRow(`SELECT payload FROM hitl_interrupts WHERE id = ?`, interruptID).Scan(&payloadJSON)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
merged, err := mergeHitlPayloadExecutionResult(payloadJSON, hitlExecutionResult{
|
||||
Success: success,
|
||||
Result: strings.TrimSpace(result),
|
||||
ToolName: strings.TrimSpace(toolName),
|
||||
ToolCallID: strings.TrimSpace(toolCallID),
|
||||
RecordedAt: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET payload = ? WHERE id = ?`, merged, interruptID)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMergeHitlPayloadExecutionResult(t *testing.T) {
|
||||
merged, err := mergeHitlPayloadExecutionResult(`{"userMessage":"hi","toolName":"nmap"}`, hitlExecutionResult{
|
||||
Success: true,
|
||||
Result: "open ports: 80",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var root map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(merged), &root); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if root["userMessage"] != "hi" {
|
||||
t.Fatalf("userMessage lost: %v", root["userMessage"])
|
||||
}
|
||||
exec, ok := root["executionResult"].(map[string]interface{})
|
||||
if !ok || exec["success"] != true {
|
||||
t.Fatalf("executionResult missing: %v", root["executionResult"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopApprovedInterruptForTool(t *testing.T) {
|
||||
m := NewHITLManager(nil, nil)
|
||||
m.TrackApprovedHitlExecution("hitl_a", "conv1", "nmap", "tc1")
|
||||
m.TrackApprovedHitlExecution("hitl_b", "conv1", "exec", "")
|
||||
if id := m.popApprovedInterruptForTool("conv1", "tc1", "nmap"); id != "hitl_a" {
|
||||
t.Fatalf("tc1 match=%q", id)
|
||||
}
|
||||
if id := m.popApprovedInterruptForTool("conv1", "", "exec"); id != "hitl_b" {
|
||||
t.Fatalf("tool name match=%q", id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func normalizeHitlReviewer(v string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "audit_agent", "agent", "ai":
|
||||
return "audit_agent"
|
||||
default:
|
||||
return "human"
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeHitlDecidedBy(v string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "audit_agent", "agent", "ai":
|
||||
return "audit_agent"
|
||||
case "system", "timeout":
|
||||
return "system"
|
||||
case "manual":
|
||||
return "manual"
|
||||
default:
|
||||
return "human"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) migrateHitlSchemaColumns() {
|
||||
_, _ = m.db.Exec(`ALTER TABLE hitl_interrupts ADD COLUMN decided_by TEXT NOT NULL DEFAULT 'human'`)
|
||||
_, _ = m.db.Exec(`ALTER TABLE hitl_conversation_configs ADD COLUMN reviewer TEXT NOT NULL DEFAULT 'human'`)
|
||||
}
|
||||
|
||||
func hitlInterruptRowToMap(
|
||||
id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string,
|
||||
messageID sql.NullString,
|
||||
decision, comment sql.NullString,
|
||||
createdAt time.Time,
|
||||
decidedAt sql.NullTime,
|
||||
) map[string]interface{} {
|
||||
msgID := ""
|
||||
if messageID.Valid {
|
||||
msgID = messageID.String
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"id": id,
|
||||
"conversationId": cid,
|
||||
"messageId": msgID,
|
||||
"mode": mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
"status": rowStatus,
|
||||
"decision": decision.String,
|
||||
"comment": comment.String,
|
||||
"decidedBy": decidedBy,
|
||||
"createdAt": createdAt,
|
||||
"decidedAt": func() interface{} {
|
||||
if decidedAt.Valid {
|
||||
return decidedAt.Time
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) buildHitlListQuery(logs bool) (string, []interface{}) {
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if logs {
|
||||
q += " AND status != 'pending'"
|
||||
} else {
|
||||
q += " AND status = 'pending'"
|
||||
}
|
||||
return q, args
|
||||
}
|
||||
|
||||
func (h *AgentHandler) appendHitlListFilters(q string, args []interface{}, c *gin.Context) (string, []interface{}) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
toolName := strings.TrimSpace(c.Query("toolName"))
|
||||
decision := strings.TrimSpace(c.Query("decision"))
|
||||
decidedBy := strings.TrimSpace(c.Query("decidedBy"))
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
search := strings.TrimSpace(c.Query("q"))
|
||||
|
||||
if conversationID != "" {
|
||||
q += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if toolName != "" {
|
||||
q += " AND tool_name LIKE ?"
|
||||
args = append(args, "%"+toolName+"%")
|
||||
}
|
||||
if decision != "" && decision != "all" {
|
||||
q += " AND decision = ?"
|
||||
args = append(args, decision)
|
||||
}
|
||||
if decidedBy != "" && decidedBy != "all" {
|
||||
q += " AND COALESCE(decided_by,'human') = ?"
|
||||
args = append(args, normalizeHitlDecidedBy(decidedBy))
|
||||
}
|
||||
if status != "" && status != "all" {
|
||||
q += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
if search != "" {
|
||||
like := "%" + search + "%"
|
||||
q += " AND (id LIKE ? OR conversation_id LIKE ? OR tool_name LIKE ? OR payload LIKE ? OR COALESCE(decision_comment,'') LIKE ?)"
|
||||
args = append(args, like, like, like, like, like)
|
||||
}
|
||||
return q, args
|
||||
}
|
||||
|
||||
func (h *AgentHandler) scanHitlInterruptRows(rows *sql.Rows) ([]map[string]interface{}, error) {
|
||||
items := make([]map[string]interface{}, 0)
|
||||
for rows.Next() {
|
||||
var id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, hitlInterruptRowToMap(id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) countHitlQuery(baseQ string, args []interface{}) (int, error) {
|
||||
countQ := "SELECT COUNT(*) FROM (" + baseQ + ") AS hitl_cnt"
|
||||
var total int
|
||||
if err := h.db.QueryRow(countQ, args...).Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) ListHITLLogs(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
q, args := h.buildHitlListQuery(true)
|
||||
q, args = h.appendHitlListFilters(q, args, c)
|
||||
total, err := h.countHitlQuery(q, args)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
q += " ORDER BY COALESCE(decided_at, created_at) DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, pageSize, offset)
|
||||
rows, err := h.db.Query(q, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
items, err := h.scanHitlInterruptRows(rows)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) GetHITLLog(c *gin.Context) {
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE id = ?`
|
||||
var rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
err := h.db.QueryRow(q, id).Scan(&rowID, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, hitlInterruptRowToMap(rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
||||
}
|
||||
@@ -43,6 +43,9 @@ type AgentTask struct {
|
||||
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
||||
activeEinoExecuteAbortNote string
|
||||
|
||||
// hitlCognition 本轮运行中供 HITL/审计 Agent 读取的上下文(用户原话 + 思考,不含会话历史)
|
||||
hitlCognition *hitlCognitionState
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
@@ -354,6 +357,7 @@ func (m *AgentTaskManager) StartTask(conversationID, message string, cancel cont
|
||||
}
|
||||
|
||||
m.tasks[conversationID] = task
|
||||
task.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(message)}
|
||||
return task, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user