mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-28 02:02:27 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,134 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AgentHandler Agent处理器
|
||||
type AgentHandler struct {
|
||||
agent *agent.Agent
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
|
||||
return &AgentHandler{
|
||||
agent: agent,
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
}
|
||||
|
||||
// ChatResponse 聊天响应
|
||||
type ChatResponse struct {
|
||||
Response string `json:"response"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表
|
||||
ConversationID string `json:"conversationId"` // 对话ID
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
// AgentLoop 处理Agent Loop请求
|
||||
func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到Agent Loop请求",
|
||||
zap.String("message", req.Message),
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
// 如果没有对话ID,创建新对话
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := req.Message
|
||||
if len(title) > 50 {
|
||||
title = title[:50] + "..."
|
||||
}
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
conversationID = conv.ID
|
||||
}
|
||||
|
||||
// 获取历史消息(排除当前消息,因为还没保存)
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
historyMessages = []database.Message{}
|
||||
}
|
||||
|
||||
h.logger.Info("获取历史消息",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("count", len(historyMessages)),
|
||||
)
|
||||
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for i, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
contentPreview := msg.Content
|
||||
if len(contentPreview) > 50 {
|
||||
contentPreview = contentPreview[:50] + "..."
|
||||
}
|
||||
h.logger.Info("添加历史消息",
|
||||
zap.Int("index", i),
|
||||
zap.String("role", msg.Role),
|
||||
zap.String("content", contentPreview),
|
||||
)
|
||||
}
|
||||
|
||||
h.logger.Info("历史消息转换完成",
|
||||
zap.Int("originalCount", len(historyMessages)),
|
||||
zap.Int("convertedCount", len(agentHistoryMessages)),
|
||||
)
|
||||
|
||||
// 保存用户消息
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 执行Agent Loop,传入历史消息
|
||||
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 保存助手回复
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ChatResponse{
|
||||
Response: result.Response,
|
||||
MCPExecutionIDs: result.MCPExecutionIDs,
|
||||
ConversationID: conversationID,
|
||||
Time: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConversationHandler 对话处理器
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||
return &ConversationHandler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConversationRequest 创建对话请求
|
||||
type CreateConversationRequest struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
var req CreateConversationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
title := req.Title
|
||||
if title == "" {
|
||||
title = "新对话"
|
||||
}
|
||||
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// ListConversations 列出对话
|
||||
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
conversations, err := h.db.ListConversations(limit, offset)
|
||||
if err != nil {
|
||||
h.logger.Error("获取对话列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conversations)
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (h *ConversationHandler) GetConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
conv, err := h.db.GetConversation(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// DeleteConversation 删除对话
|
||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := h.db.DeleteConversation(id); err != nil {
|
||||
h.logger.Error("删除对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// MonitorHandler 监控处理器
|
||||
type MonitorHandler struct {
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
logger *zap.Logger
|
||||
vulns []security.Vulnerability
|
||||
}
|
||||
|
||||
// NewMonitorHandler 创建新的监控处理器
|
||||
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, logger *zap.Logger) *MonitorHandler {
|
||||
return &MonitorHandler{
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
logger: logger,
|
||||
vulns: []security.Vulnerability{},
|
||||
}
|
||||
}
|
||||
|
||||
// MonitorResponse 监控响应
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Vulnerabilities []security.Vulnerability `json:"vulnerabilities"`
|
||||
Report map[string]interface{} `json:"report"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Monitor 获取监控信息
|
||||
func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
// 获取所有执行记录
|
||||
executions := h.mcpServer.GetAllExecutions()
|
||||
|
||||
// 分析执行结果,提取漏洞
|
||||
for _, exec := range executions {
|
||||
if exec.Status == "completed" && exec.Result != nil {
|
||||
vulns := h.executor.AnalyzeResults(exec.ToolName, exec.Result)
|
||||
h.vulns = append(h.vulns, vulns...)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取统计信息
|
||||
stats := h.mcpServer.GetStats()
|
||||
|
||||
// 生成报告
|
||||
report := h.executor.GetVulnerabilityReport(h.vulns)
|
||||
|
||||
c.JSON(http.StatusOK, MonitorResponse{
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Vulnerabilities: h.vulns,
|
||||
Report: report,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetExecution 获取特定执行记录
|
||||
func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
exec, exists := h.mcpServer.GetExecution(id)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, exec)
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (h *MonitorHandler) GetStats(c *gin.Context) {
|
||||
stats := h.mcpServer.GetStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// GetVulnerabilities 获取漏洞列表
|
||||
func (h *MonitorHandler) GetVulnerabilities(c *gin.Context) {
|
||||
report := h.executor.GetVulnerabilityReport(h.vulns)
|
||||
c.JSON(http.StatusOK, report)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user