Add files via upload

This commit is contained in:
公明
2025-11-08 18:56:23 +08:00
committed by GitHub
commit add33e1cf7
24 changed files with 5228 additions and 0 deletions
+134
View File
@@ -0,0 +1,134 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AgentHandler Agent处理器
type AgentHandler struct {
agent *agent.Agent
db *database.DB
logger *zap.Logger
}
// NewAgentHandler 创建新的Agent处理器
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
return &AgentHandler{
agent: agent,
db: db,
logger: logger,
}
}
// ChatRequest 聊天请求
type ChatRequest struct {
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
}
// ChatResponse 聊天响应
type ChatResponse struct {
Response string `json:"response"`
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表
ConversationID string `json:"conversationId"` // 对话ID
Time time.Time `json:"time"`
}
// AgentLoop 处理Agent Loop请求
func (h *AgentHandler) AgentLoop(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.logger.Info("收到Agent Loop请求",
zap.String("message", req.Message),
zap.String("conversationId", req.ConversationID),
)
// 如果没有对话ID,创建新对话
conversationID := req.ConversationID
if conversationID == "" {
title := req.Message
if len(title) > 50 {
title = title[:50] + "..."
}
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
conversationID = conv.ID
}
// 获取历史消息(排除当前消息,因为还没保存)
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
historyMessages = []database.Message{}
}
h.logger.Info("获取历史消息",
zap.String("conversationId", conversationID),
zap.Int("count", len(historyMessages)),
)
// 将数据库消息转换为Agent消息格式
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
for i, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
contentPreview := msg.Content
if len(contentPreview) > 50 {
contentPreview = contentPreview[:50] + "..."
}
h.logger.Info("添加历史消息",
zap.Int("index", i),
zap.String("role", msg.Role),
zap.String("content", contentPreview),
)
}
h.logger.Info("历史消息转换完成",
zap.Int("originalCount", len(historyMessages)),
zap.Int("convertedCount", len(agentHistoryMessages)),
)
// 保存用户消息
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 执行Agent Loop,传入历史消息
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 保存助手回复
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.Error(err))
}
c.JSON(http.StatusOK, ChatResponse{
Response: result.Response,
MCPExecutionIDs: result.MCPExecutionIDs,
ConversationID: conversationID,
Time: time.Now(),
})
}
+102
View File
@@ -0,0 +1,102 @@
package handler
import (
"net/http"
"strconv"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ConversationHandler 对话处理器
type ConversationHandler struct {
db *database.DB
logger *zap.Logger
}
// NewConversationHandler 创建新的对话处理器
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
return &ConversationHandler{
db: db,
logger: logger,
}
}
// CreateConversationRequest 创建对话请求
type CreateConversationRequest struct {
Title string `json:"title"`
}
// CreateConversation 创建新对话
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
var req CreateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
title := req.Title
if title == "" {
title = "新对话"
}
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conv)
}
// ListConversations 列出对话
func (h *ConversationHandler) ListConversations(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50")
offsetStr := c.DefaultQuery("offset", "0")
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
if limit <= 0 || limit > 100 {
limit = 50
}
conversations, err := h.db.ListConversations(limit, offset)
if err != nil {
h.logger.Error("获取对话列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conversations)
}
// GetConversation 获取对话
func (h *ConversationHandler) GetConversation(c *gin.Context) {
id := c.Param("id")
conv, err := h.db.GetConversation(id)
if err != nil {
h.logger.Error("获取对话失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
c.JSON(http.StatusOK, conv)
}
// DeleteConversation 删除对话
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteConversation(id); err != nil {
h.logger.Error("删除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
+92
View File
@@ -0,0 +1,92 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// MonitorHandler 监控处理器
type MonitorHandler struct {
mcpServer *mcp.Server
executor *security.Executor
logger *zap.Logger
vulns []security.Vulnerability
}
// NewMonitorHandler 创建新的监控处理器
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, logger *zap.Logger) *MonitorHandler {
return &MonitorHandler{
mcpServer: mcpServer,
executor: executor,
logger: logger,
vulns: []security.Vulnerability{},
}
}
// MonitorResponse 监控响应
type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"`
Vulnerabilities []security.Vulnerability `json:"vulnerabilities"`
Report map[string]interface{} `json:"report"`
Timestamp time.Time `json:"timestamp"`
}
// Monitor 获取监控信息
func (h *MonitorHandler) Monitor(c *gin.Context) {
// 获取所有执行记录
executions := h.mcpServer.GetAllExecutions()
// 分析执行结果,提取漏洞
for _, exec := range executions {
if exec.Status == "completed" && exec.Result != nil {
vulns := h.executor.AnalyzeResults(exec.ToolName, exec.Result)
h.vulns = append(h.vulns, vulns...)
}
}
// 获取统计信息
stats := h.mcpServer.GetStats()
// 生成报告
report := h.executor.GetVulnerabilityReport(h.vulns)
c.JSON(http.StatusOK, MonitorResponse{
Executions: executions,
Stats: stats,
Vulnerabilities: h.vulns,
Report: report,
Timestamp: time.Now(),
})
}
// GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id")
exec, exists := h.mcpServer.GetExecution(id)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
return
}
c.JSON(http.StatusOK, exec)
}
// GetStats 获取统计信息
func (h *MonitorHandler) GetStats(c *gin.Context) {
stats := h.mcpServer.GetStats()
c.JSON(http.StatusOK, stats)
}
// GetVulnerabilities 获取漏洞列表
func (h *MonitorHandler) GetVulnerabilities(c *gin.Context) {
report := h.executor.GetVulnerabilityReport(h.vulns)
c.JSON(http.StatusOK, report)
}