mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-19 22:38:56 +02:00
Add files via upload
This commit is contained in:
@@ -195,10 +195,21 @@ func (db *DB) DeleteGroup(id string) error {
|
||||
}
|
||||
|
||||
// AddConversationToGroup 将对话添加到分组
|
||||
// 注意:一个对话只能属于一个分组,所以在添加新分组之前,会先删除该对话的所有旧分组关联
|
||||
func (db *DB) AddConversationToGroup(conversationID, groupID string) error {
|
||||
id := uuid.New().String()
|
||||
// 先删除该对话的所有旧分组关联,确保一个对话只属于一个分组
|
||||
_, err := db.Exec(
|
||||
"INSERT OR REPLACE INTO conversation_group_mappings (id, conversation_id, group_id, created_at) VALUES (?, ?, ?, ?)",
|
||||
"DELETE FROM conversation_group_mappings WHERE conversation_id = ?",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话旧分组关联失败: %w", err)
|
||||
}
|
||||
|
||||
// 然后插入新的分组关联
|
||||
id := uuid.New().String()
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversation_group_mappings (id, conversation_id, group_id, created_at) VALUES (?, ?, ?, ?)",
|
||||
id, conversationID, groupID, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -90,7 +90,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, conversationID, severity, status string) ([]*Vulnerability, error) {
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) {
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
@@ -100,6 +100,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, conversationID, severity, s
|
||||
`
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
@@ -16,6 +17,47 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// safeTruncateString 安全截断字符串,避免在 UTF-8 字符中间截断
|
||||
func safeTruncateString(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
if utf8.RuneCountInString(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
|
||||
// 将字符串转换为 rune 切片以正确计算字符数
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
|
||||
// 截断到最大长度
|
||||
truncated := string(runes[:maxLen])
|
||||
|
||||
// 尝试在标点符号或空格处截断,使截断更自然
|
||||
// 在截断点往前查找合适的断点(不超过20%的长度)
|
||||
searchRange := maxLen / 5
|
||||
if searchRange > maxLen {
|
||||
searchRange = maxLen
|
||||
}
|
||||
breakChars := []rune(",。、 ,.;:!?!?/\\-_")
|
||||
bestBreakPos := len(runes[:maxLen])
|
||||
|
||||
for i := bestBreakPos - 1; i >= bestBreakPos-searchRange && i >= 0; i-- {
|
||||
for _, breakChar := range breakChars {
|
||||
if runes[i] == breakChar {
|
||||
bestBreakPos = i + 1 // 在标点符号后断开
|
||||
goto found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
found:
|
||||
truncated = string(runes[:bestBreakPos])
|
||||
return truncated + "..."
|
||||
}
|
||||
|
||||
// AgentHandler Agent处理器
|
||||
type AgentHandler struct {
|
||||
agent *agent.Agent
|
||||
@@ -74,10 +116,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
// 如果没有对话ID,创建新对话
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := req.Message
|
||||
if len(title) > 50 {
|
||||
title = title[:50] + "..."
|
||||
}
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
@@ -237,10 +276,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
// 如果没有对话ID,创建新对话
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := req.Message
|
||||
if len(title) > 50 {
|
||||
title = title[:50] + "..."
|
||||
}
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
|
||||
@@ -86,6 +86,7 @@ func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
|
||||
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
id := c.Query("id")
|
||||
conversationID := c.Query("conversation_id")
|
||||
severity := c.Query("severity")
|
||||
status := c.Query("status")
|
||||
@@ -97,7 +98,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, conversationID, severity, status)
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
Reference in New Issue
Block a user