Add files via upload

This commit is contained in:
公明
2026-05-27 11:45:51 +08:00
committed by GitHub
parent 6bb3a73f73
commit d913695303
7 changed files with 408 additions and 31 deletions
+146 -29
View File
@@ -6,6 +6,7 @@ import (
"strings"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/project"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -29,12 +30,13 @@ type createProjectRequest struct {
Status string `json:"status"`
}
// updateProjectRequest 部分更新:字段省略表示不修改;传 null 或 "" 可清空字符串字段。
type updateProjectRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ScopeJSON string `json:"scope_json"`
Status string `json:"status"`
Pinned *bool `json:"pinned"`
Name *string `json:"name"`
Description *string `json:"description"`
ScopeJSON *string `json:"scope_json"`
Status *string `json:"status"`
Pinned *bool `json:"pinned"`
}
// CreateProject POST /api/projects
@@ -75,6 +77,46 @@ func (h *ProjectHandler) ListProjects(c *gin.Context) {
c.JSON(http.StatusOK, list)
}
// GetProjectStats GET /api/projects/:id/stats
func (h *ProjectHandler) GetProjectStats(c *gin.Context) {
stats, err := project.GetProjectStats(h.db, c.Param("id"))
if err != nil {
if strings.Contains(err.Error(), "不存在") {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, stats)
}
// ListProjectConversations GET /api/projects/:id/conversations
func (h *ProjectHandler) ListProjectConversations(c *gin.Context) {
projectID := c.Param("id")
if _, err := h.db.GetProject(projectID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
offset, _ := strconv.Atoi(c.Query("offset"))
list, err := h.db.ListConversationsByProjectID(projectID, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.Conversation{}
}
total, _ := h.db.CountConversationsByProjectID(projectID)
c.JSON(http.StatusOK, gin.H{
"conversations": list,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetProject GET /api/projects/:id
func (h *ProjectHandler) GetProject(c *gin.Context) {
p, err := h.db.GetProject(c.Param("id"))
@@ -98,17 +140,21 @@ func (h *ProjectHandler) UpdateProject(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if s := strings.TrimSpace(req.Name); s != "" {
p.Name = s
if req.Name != nil {
if s := strings.TrimSpace(*req.Name); s != "" {
p.Name = s
}
}
if req.Description != "" || c.Request.ContentLength > 0 {
p.Description = req.Description
if req.Description != nil {
p.Description = *req.Description
}
if req.ScopeJSON != "" || c.GetHeader("Content-Type") != "" {
p.ScopeJSON = req.ScopeJSON
if req.ScopeJSON != nil {
p.ScopeJSON = *req.ScopeJSON
}
if s := strings.TrimSpace(req.Status); s != "" {
p.Status = s
if req.Status != nil {
if s := strings.TrimSpace(*req.Status); s != "" {
p.Status = s
}
}
if req.Pinned != nil {
p.Pinned = *req.Pinned
@@ -139,6 +185,18 @@ type upsertFactRequest struct {
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
}
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
type updateFactRequest struct {
FactKey *string `json:"fact_key"`
Category *string `json:"category"`
Summary *string `json:"summary"`
Body *string `json:"body"`
Confidence *string `json:"confidence"`
Pinned *bool `json:"pinned"`
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
ClearBody bool `json:"clear_body"`
}
// ListFacts GET /api/projects/:id/facts fact_key 查询参数可获取单条详情)
func (h *ProjectHandler) ListFacts(c *gin.Context) {
projectID := c.Param("id")
@@ -154,9 +212,13 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
offset, _ := strconv.Atoi(c.Query("offset"))
filter := database.ProjectFactListFilter{
Category: c.Query("category"),
Confidence: c.Query("confidence"),
Search: c.Query("search"),
Category: c.Query("category"),
Confidence: c.Query("confidence"),
Search: c.Query("search"),
RelatedVulnerabilityID: c.Query("related_vulnerability_id"),
}
if c.Query("exclude_deprecated") == "1" || c.Query("exclude_deprecated") == "true" {
filter.ExcludeDeprecated = true
}
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
if err != nil {
@@ -166,6 +228,53 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
if list == nil {
list = []*database.ProjectFact{}
}
if sparseOnly := c.Query("sparse_only"); sparseOnly == "1" || sparseOnly == "true" {
filtered := make([]*database.ProjectFact, 0, len(list))
for _, f := range list {
if project.IsSparseFactBody(f.Category, f.FactKey, f.Body) {
filtered = append(filtered, f)
}
}
list = filtered
}
c.JSON(http.StatusOK, list)
}
// GetFactPreviousVersion GET /api/projects/:id/facts/:factId/previous-version
func (h *ProjectHandler) GetFactPreviousVersion(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
if strings.TrimSpace(existing.SupersedesFactID) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "无上一版本"})
return
}
v, err := h.db.GetProjectFactVersion(existing.SupersedesFactID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, v)
}
// ListFactVersions GET /api/projects/:id/facts/:factId/versions
func (h *ProjectHandler) ListFactVersions(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
list, err := h.db.ListProjectFactVersions(existing.ID, limit)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.ProjectFactVersion{}
}
c.JSON(http.StatusOK, list)
}
@@ -201,28 +310,36 @@ func (h *ProjectHandler) UpdateFact(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
var req upsertFactRequest
var req updateFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if k := strings.TrimSpace(req.FactKey); k != "" {
existing.FactKey = k
if req.FactKey != nil {
if k := strings.TrimSpace(*req.FactKey); k != "" {
existing.FactKey = k
}
}
if req.Category != "" {
existing.Category = req.Category
if req.Category != nil && strings.TrimSpace(*req.Category) != "" {
existing.Category = *req.Category
}
if req.Summary != "" {
existing.Summary = req.Summary
if req.Summary != nil && strings.TrimSpace(*req.Summary) != "" {
existing.Summary = *req.Summary
}
if strings.TrimSpace(req.Body) != "" {
existing.Body = req.Body
if req.ClearBody {
existing.Body = ""
} else if req.Body != nil {
existing.Body = *req.Body
}
if req.Confidence != "" {
existing.Confidence = req.Confidence
if req.Confidence != nil && strings.TrimSpace(*req.Confidence) != "" {
existing.Confidence = *req.Confidence
}
if req.Pinned != nil {
existing.Pinned = *req.Pinned
}
if req.RelatedVulnerabilityID != nil {
existing.RelatedVulnerabilityID = *req.RelatedVulnerabilityID
}
existing.Pinned = req.Pinned
existing.RelatedVulnerabilityID = req.RelatedVulnerabilityID
updated, err := h.db.UpsertProjectFact(existing)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+1 -1
View File
@@ -23,7 +23,7 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
if err != nil || projectID == "" {
return ""
}
block, err := project.BuildFactIndexBlock(h.db, projectID, h.config.Project)
block, err := project.BuildProjectBlackboardBlock(h.db, projectID, h.config.Project)
if err != nil {
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
return ""
+1 -1
View File
@@ -65,7 +65,7 @@ func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint,
b.WriteString(conn.ID)
b.WriteString("\"):")
b.WriteString(webshellAssistantToolList)
b.WriteString("。")
b.WriteString("。边渗透边记录:每确认新认知即 upsert_project_fact,每验证漏洞即 record_vulnerability,勿等会话结束。")
b.WriteString(skillHint)
b.WriteString("\n\n用户请求:")
b.WriteString(userMsg)