mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-30 17:55:32 +02:00
264 lines
8.2 KiB
Go
264 lines
8.2 KiB
Go
package handler
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"math"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"cyberstrike-ai/internal/config"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func normalizeHitlReviewer(v string) string {
|
|
switch strings.ToLower(strings.TrimSpace(v)) {
|
|
case "audit_agent", "agent", "ai":
|
|
return "audit_agent"
|
|
default:
|
|
return "human"
|
|
}
|
|
}
|
|
|
|
func normalizeHitlDecidedBy(v string) string {
|
|
switch strings.ToLower(strings.TrimSpace(v)) {
|
|
case "audit_agent", "agent", "ai":
|
|
return "audit_agent"
|
|
case "system", "timeout":
|
|
return "system"
|
|
case "manual":
|
|
return "manual"
|
|
default:
|
|
return "human"
|
|
}
|
|
}
|
|
|
|
func (m *HITLManager) migrateHitlSchemaColumns() {
|
|
_, _ = m.db.Exec(`ALTER TABLE hitl_interrupts ADD COLUMN decided_by TEXT NOT NULL DEFAULT 'human'`)
|
|
_, _ = m.db.Exec(`ALTER TABLE hitl_conversation_configs ADD COLUMN reviewer TEXT NOT NULL DEFAULT 'human'`)
|
|
}
|
|
|
|
func hitlInterruptRowToMap(
|
|
id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string,
|
|
messageID sql.NullString,
|
|
decision, comment sql.NullString,
|
|
createdAt time.Time,
|
|
decidedAt sql.NullTime,
|
|
) map[string]interface{} {
|
|
msgID := ""
|
|
if messageID.Valid {
|
|
msgID = messageID.String
|
|
}
|
|
return map[string]interface{}{
|
|
"id": id,
|
|
"conversationId": cid,
|
|
"messageId": msgID,
|
|
"mode": mode,
|
|
"toolName": toolName,
|
|
"toolCallId": toolCallID,
|
|
"payload": payload,
|
|
"status": rowStatus,
|
|
"decision": decision.String,
|
|
"comment": comment.String,
|
|
"decidedBy": decidedBy,
|
|
"createdAt": createdAt,
|
|
"decidedAt": func() interface{} {
|
|
if decidedAt.Valid {
|
|
return decidedAt.Time
|
|
}
|
|
return nil
|
|
}(),
|
|
}
|
|
}
|
|
|
|
func (h *AgentHandler) buildHitlListQuery(logs bool) (string, []interface{}) {
|
|
where, args := h.buildHitlLogsWhere(logs)
|
|
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts` + where
|
|
return q, args
|
|
}
|
|
|
|
func (h *AgentHandler) buildHitlLogsWhere(logs bool) (string, []interface{}) {
|
|
q := " WHERE 1=1"
|
|
args := []interface{}{}
|
|
if logs {
|
|
q += " AND status != 'pending'"
|
|
} else {
|
|
q += " AND status = 'pending'"
|
|
}
|
|
return q, args
|
|
}
|
|
|
|
func (h *AgentHandler) appendHitlListFilters(q string, args []interface{}, c *gin.Context) (string, []interface{}) {
|
|
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
|
toolName := strings.TrimSpace(c.Query("toolName"))
|
|
decision := strings.TrimSpace(c.Query("decision"))
|
|
decidedBy := strings.TrimSpace(c.Query("decidedBy"))
|
|
status := strings.TrimSpace(c.Query("status"))
|
|
search := strings.TrimSpace(c.Query("q"))
|
|
|
|
if conversationID != "" {
|
|
q += " AND conversation_id = ?"
|
|
args = append(args, conversationID)
|
|
}
|
|
if toolName != "" {
|
|
q += " AND tool_name LIKE ?"
|
|
args = append(args, "%"+toolName+"%")
|
|
}
|
|
if decision != "" && decision != "all" {
|
|
q += " AND decision = ?"
|
|
args = append(args, decision)
|
|
}
|
|
if decidedBy != "" && decidedBy != "all" {
|
|
q += " AND COALESCE(decided_by,'human') = ?"
|
|
args = append(args, normalizeHitlDecidedBy(decidedBy))
|
|
}
|
|
if status != "" && status != "all" {
|
|
q += " AND status = ?"
|
|
args = append(args, status)
|
|
}
|
|
if search != "" {
|
|
like := "%" + search + "%"
|
|
q += " AND (id LIKE ? OR conversation_id LIKE ? OR tool_name LIKE ? OR payload LIKE ? OR COALESCE(decision_comment,'') LIKE ?)"
|
|
args = append(args, like, like, like, like, like)
|
|
}
|
|
return q, args
|
|
}
|
|
|
|
func (h *AgentHandler) scanHitlInterruptRows(rows *sql.Rows) ([]map[string]interface{}, error) {
|
|
items := make([]map[string]interface{}, 0)
|
|
for rows.Next() {
|
|
var id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
|
var messageID sql.NullString
|
|
var decision, comment sql.NullString
|
|
var createdAt time.Time
|
|
var decidedAt sql.NullTime
|
|
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt); err != nil {
|
|
continue
|
|
}
|
|
items = append(items, hitlInterruptRowToMap(id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (h *AgentHandler) countHitlQuery(baseQ string, args []interface{}) (int, error) {
|
|
countQ := "SELECT COUNT(*) FROM (" + baseQ + ") AS hitl_cnt"
|
|
var total int
|
|
if err := h.db.QueryRow(countQ, args...).Scan(&total); err != nil {
|
|
return 0, err
|
|
}
|
|
return total, nil
|
|
}
|
|
|
|
func (h *AgentHandler) ListHITLLogs(c *gin.Context) {
|
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
|
if page < 1 {
|
|
page = 1
|
|
}
|
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
|
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
|
offset := (page - 1) * pageSize
|
|
|
|
q, args := h.buildHitlListQuery(true)
|
|
q, args = h.appendHitlListFilters(q, args, c)
|
|
total, err := h.countHitlQuery(q, args)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
q += " ORDER BY COALESCE(decided_at, created_at) DESC LIMIT ? OFFSET ?"
|
|
args = append(args, pageSize, offset)
|
|
rows, err := h.db.Query(q, args...)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
items, err := h.scanHitlInterruptRows(rows)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total, "retentionDays": h.hitlRetentionDays()})
|
|
}
|
|
|
|
func (h *AgentHandler) hitlRetentionDays() int {
|
|
if h.config != nil {
|
|
return h.config.Hitl.RetentionDaysEffective()
|
|
}
|
|
return config.HitlConfig{}.RetentionDaysEffective()
|
|
}
|
|
|
|
// DeleteHITLLogs 批量删除或按筛选清空已决策的人机协同审计日志(不删除 pending)。
|
|
func (h *AgentHandler) DeleteHITLLogs(c *gin.Context) {
|
|
var request struct {
|
|
IDs []string `json:"ids"`
|
|
All bool `json:"all"`
|
|
}
|
|
if err := c.ShouldBindJSON(&request); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
|
return
|
|
}
|
|
|
|
var deleted int64
|
|
var err error
|
|
if request.All {
|
|
where, args := h.buildHitlLogsWhere(true)
|
|
where, args = h.appendHitlListFilters(where, args, c)
|
|
deleted, err = h.db.DeleteHitlInterruptLogsMatching(where, args)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
if h.audit != nil {
|
|
h.audit.RecordOK(c, "hitl", "logs_clear", "清空人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
|
|
"deleted": deleted,
|
|
})
|
|
}
|
|
} else {
|
|
if len(request.IDs) == 0 {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "审计日志 ID 列表不能为空"})
|
|
return
|
|
}
|
|
deleted, err = h.db.DeleteHitlInterruptLogsByIDs(request.IDs)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
if h.audit != nil {
|
|
h.audit.RecordOK(c, "hitl", "logs_delete_batch", "批量删除人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
|
|
"count": len(request.IDs),
|
|
"deleted": deleted,
|
|
})
|
|
}
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{"message": "删除成功", "deleted": deleted})
|
|
}
|
|
|
|
func (h *AgentHandler) GetHITLLog(c *gin.Context) {
|
|
id := strings.TrimSpace(c.Param("id"))
|
|
if id == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
|
return
|
|
}
|
|
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE id = ?`
|
|
var rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
|
var messageID sql.NullString
|
|
var decision, comment sql.NullString
|
|
var createdAt time.Time
|
|
var decidedAt sql.NullTime
|
|
err := h.db.QueryRow(q, id).Scan(&rowID, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
|
return
|
|
}
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, hitlInterruptRowToMap(rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
|
}
|