diff --git a/internal/app/app.go b/internal/app/app.go index e39dfbad..5a2cef85 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -96,6 +96,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error } auditSvc := audit.NewService(db, cfg, log.Logger) + audit.RegisterConversationCreateHook(auditSvc) auditSvc.PurgeExpired() audit.StartRetentionLoop(auditSvc, log.Logger) diff --git a/internal/audit/conversation_create.go b/internal/audit/conversation_create.go new file mode 100644 index 00000000..82e19b54 --- /dev/null +++ b/internal/audit/conversation_create.go @@ -0,0 +1,55 @@ +package audit + +import ( + "strings" + + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/security" + + "github.com/gin-gonic/gin" +) + +// RegisterConversationCreateHook records platform audit rows for every new conversation. +func RegisterConversationCreateHook(s *Service) { + if s == nil { + return + } + database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) { + detail := map[string]interface{}{ + "title": conv.Title, + "source": meta.Source, + } + if meta.WebShellConnectionID != "" { + detail["webshell_connection_id"] = meta.WebShellConnectionID + } + s.Record(nil, Entry{ + Category: "conversation", + Action: "create", + Result: "success", + Message: "创建对话", + ResourceType: "conversation", + ResourceID: conv.ID, + Detail: detail, + ClientIP: meta.ClientIP, + SessionHint: meta.SessionHint, + }) + }) +} + +// ConversationCreateMeta builds audit metadata for conversation creation. +func ConversationCreateMeta(source string) database.ConversationCreateMeta { + return database.ConversationCreateMeta{Source: strings.TrimSpace(source)} +} + +// ConversationCreateMetaFromGin includes client IP and session hint when available. +func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta { + m := ConversationCreateMeta(source) + if c == nil { + return m + } + m.ClientIP = c.ClientIP() + if token := c.GetString(security.ContextAuthTokenKey); token != "" { + m.SessionHint = sessionHint(token) + } + return m +} diff --git a/internal/audit/resource_availability.go b/internal/audit/resource_availability.go new file mode 100644 index 00000000..3b22871f --- /dev/null +++ b/internal/audit/resource_availability.go @@ -0,0 +1,86 @@ +package audit + +import ( + "strings" + + "cyberstrike-ai/internal/database" +) + +var auditActionsResourceRemoved = map[string]bool{ + "delete": true, + "item_delete": true, + "connection_delete": true, + "listener_delete": true, + "session_delete": true, + "task_delete": true, + "execution_delete": true, + "execution_delete_batch": true, + "delete_queue": true, + "delete_batch_task": true, + "markdown_delete": true, +} + +// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked. +func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) { + if log == nil || strings.TrimSpace(log.ResourceID) == "" { + return + } + if auditActionsResourceRemoved[log.Action] { + f := false + log.ResourceAvailable = &f + return + } + if db == nil { + return + } + available, known := resourceStillExists(db, log.ResourceType, log.ResourceID) + if known { + log.ResourceAvailable = &available + } +} + +func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) { + resourceID = strings.TrimSpace(resourceID) + if resourceID == "" { + return false, false + } + t := strings.TrimSpace(resourceType) + if t == "" { + if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") { + t = "conversation" + } else { + return false, false + } + } + switch t { + case "conversation": + ok, err := db.ConversationExists(resourceID) + return ok, err == nil + case "vulnerability": + _, err := db.GetVulnerability(resourceID) + if err != nil { + return false, strings.Contains(err.Error(), "不存在") + } + return true, true + case "batch_queue": + _, err := db.GetBatchQueue(resourceID) + return err == nil, true + case "c2_listener": + _, err := db.GetC2Listener(resourceID) + return err == nil, true + case "c2_session": + _, err := db.GetC2Session(resourceID) + return err == nil, true + case "c2_task": + _, err := db.GetC2Task(resourceID) + return err == nil, true + case "webshell_connection": + c, err := db.GetWebshellConnection(resourceID) + return err == nil && c != nil, true + case "tool_execution": + _, err := db.GetToolExecution(resourceID) + return err == nil, true + default: + return false, false + } +} diff --git a/internal/audit/retention.go b/internal/audit/retention.go index 83ef05d1..f882595c 100644 --- a/internal/audit/retention.go +++ b/internal/audit/retention.go @@ -6,13 +6,16 @@ import ( "go.uber.org/zap" ) +// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once). +const auditRetentionPurgeInterval = time.Hour + // StartRetentionLoop periodically purges expired audit rows. func StartRetentionLoop(s *Service, logger *zap.Logger) { if s == nil { return } go func() { - ticker := time.NewTicker(24 * time.Hour) + ticker := time.NewTicker(auditRetentionPurgeInterval) defer ticker.Stop() for range ticker.C { s.PurgeExpired() diff --git a/internal/audit/service.go b/internal/audit/service.go index eb537f33..a6cc1203 100644 --- a/internal/audit/service.go +++ b/internal/audit/service.go @@ -65,14 +65,20 @@ func (s *Service) Record(c *gin.Context, e Entry) { if strings.TrimSpace(e.Actor) == "" { e.Actor = "admin" } - if e.SessionHint == "" && c != nil { - if token := c.GetString(security.ContextAuthTokenKey); token != "" { - e.SessionHint = sessionHint(token) - } - } maxDetail := s.cfg.Audit.MaxDetailBytesEffective() detail := SanitizeDetail(e.Detail, maxDetail) + sessionHintVal := e.SessionHint + if sessionHintVal == "" && c != nil { + if token := c.GetString(security.ContextAuthTokenKey); token != "" { + sessionHintVal = sessionHint(token) + } + } + clientIPVal := e.ClientIP + if clientIPVal == "" { + clientIPVal = clientIP(c) + } + row := &database.AuditLog{ ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""), CreatedAt: time.Now(), @@ -81,8 +87,8 @@ func (s *Service) Record(c *gin.Context, e Entry) { Action: e.Action, Result: e.Result, Actor: e.Actor, - SessionHint: e.SessionHint, - ClientIP: clientIP(c), + SessionHint: sessionHintVal, + ClientIP: clientIPVal, UserAgent: userAgent(c), ResourceType: e.ResourceType, ResourceID: e.ResourceID, diff --git a/internal/audit/types.go b/internal/audit/types.go index 7876e2f7..ff83ea58 100644 --- a/internal/audit/types.go +++ b/internal/audit/types.go @@ -11,5 +11,6 @@ type Entry struct { ResourceType string ResourceID string Message string - Detail map[string]interface{} + Detail map[string]interface{} + ClientIP string // optional when c is nil (robot, batch, DB hook) } diff --git a/internal/handler/agent.go b/internal/handler/agent.go index e285e7f7..119220b1 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -560,7 +560,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { conversationID := req.ConversationID if conversationID == "" { title := safeTruncateString(req.Message, 50) - conv, err := h.db.CreateConversation(title) + conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "agent_loop")) if err != nil { h.logger.Error("创建对话失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -725,10 +725,14 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { } // ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复 -func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) { +func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) { if conversationID == "" { title := safeTruncateString(message, 50) - conv, createErr := h.db.CreateConversation(title) + src := "robot" + if strings.TrimSpace(platform) != "" { + src = "robot:" + strings.TrimSpace(platform) + } + conv, createErr := h.db.CreateConversation(title, audit.ConversationCreateMeta(src)) if createErr != nil { return "", "", fmt.Errorf("创建对话失败: %w", createErr) } @@ -1427,10 +1431,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { title := safeTruncateString(req.Message, 50) var conv *database.Conversation var err error + meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream") if req.WebShellConnectionID != "" { - conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title) + meta.Source = "webshell_chat" + conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta) } else { - conv, err = h.db.CreateConversation(title) + conv, err = h.db.CreateConversation(title, meta) } if err != nil { h.logger.Error("创建对话失败", zap.Error(err)) @@ -2559,7 +2565,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { // 创建新对话 title := safeTruncateString(task.Message, 50) - conv, err := h.db.CreateConversation(title) + conv, err := h.db.CreateConversation(title, audit.ConversationCreateMeta("batch_task")) var conversationID string if err != nil { h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) diff --git a/internal/handler/audit.go b/internal/handler/audit.go index 5649324b..7cb4dd47 100644 --- a/internal/handler/audit.go +++ b/internal/handler/audit.go @@ -116,6 +116,7 @@ func (h *AuditHandler) GetLog(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"}) return } + audit.ApplyResourceAvailability(h.db, row) c.JSON(http.StatusOK, gin.H{"log": row}) } diff --git a/internal/handler/conversation.go b/internal/handler/conversation.go index 840b31e0..e3e62c98 100644 --- a/internal/handler/conversation.go +++ b/internal/handler/conversation.go @@ -49,7 +49,7 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) { title = "新对话" } - conv, err := h.db.CreateConversation(title) + conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api")) if err != nil { h.logger.Error("创建对话失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index 8ffd757e..d51a9cfe 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -90,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { zap.String("conversationId", req.ConversationID), ) - prep, err := h.prepareMultiAgentSession(&req) + prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream") if err != nil { sendEvent("error", err.Error(), nil) sendEvent("done", "", nil) @@ -326,7 +326,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID)) - prep, err := h.prepareMultiAgentSession(&req) + prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent") if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index 142a7755..8a707186 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -107,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { zap.String("conversationId", req.ConversationID), ) - prep, err := h.prepareMultiAgentSession(&req) + prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream") if err != nil { sendEvent("error", err.Error(), nil) sendEvent("done", "", nil) @@ -347,7 +347,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID)) - prep, err := h.prepareMultiAgentSession(&req) + prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent") if err != nil { status, msg := multiAgentHTTPErrorStatus(err) c.JSON(status, gin.H{"error": msg}) diff --git a/internal/handler/multi_agent_prepare.go b/internal/handler/multi_agent_prepare.go index 0d35ee7c..3ce2e042 100644 --- a/internal/handler/multi_agent_prepare.go +++ b/internal/handler/multi_agent_prepare.go @@ -5,9 +5,11 @@ import ( "strings" "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/audit" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/mcp/builtin" + "github.com/gin-gonic/gin" "go.uber.org/zap" ) @@ -22,7 +24,7 @@ type multiAgentPrepared struct { UserMessageID string } -func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) { +func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) { if len(req.Attachments) > maxAttachments { return nil, fmt.Errorf("附件最多 %d 个", maxAttachments) } @@ -33,10 +35,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr title := safeTruncateString(req.Message, 50) var conv *database.Conversation var err error + meta := audit.ConversationCreateMetaFromGin(c, source) if strings.TrimSpace(req.WebShellConnectionID) != "" { - conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title) + meta.Source = source + "_webshell" + meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID) + conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta) } else { - conv, err = h.db.CreateConversation(title) + conv, err = h.db.CreateConversation(title, meta) } if err != nil { return nil, fmt.Errorf("创建对话失败: %w", err) diff --git a/internal/handler/robot.go b/internal/handler/robot.go index 37bbf311..2f4aa8de 100644 --- a/internal/handler/robot.go +++ b/internal/handler/robot.go @@ -133,7 +133,7 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) ( } else { t = safeTruncateString(t, 50) } - conv, err := h.db.CreateConversation(t) + conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform}) if err != nil { h.logger.Warn("创建机器人会话失败", zap.Error(err)) return "", false @@ -188,7 +188,7 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) { // clearConversation 清空当前会话(切换到新对话) func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) { title := "新对话 " + time.Now().Format("01-02 15:04") - conv, err := h.db.CreateConversation(title) + conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}) if err != nil { h.logger.Warn("创建新对话失败", zap.Error(err)) return "" @@ -242,7 +242,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin h.cancelMu.Unlock() }() role := h.getRole(platform, userID) - resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role) + resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role) if err != nil { h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err)) if errors.Is(err, context.Canceled) {