diff --git a/internal/project/user_verbatim_anchor.go b/internal/project/user_verbatim_anchor.go new file mode 100644 index 00000000..bed7676f --- /dev/null +++ b/internal/project/user_verbatim_anchor.go @@ -0,0 +1,170 @@ +package project + +import ( + "fmt" + "strings" + + "cyberstrike-ai/internal/database" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" +) + +const ( + // UserVerbatimSectionHeading 用户原文锚点可读标题(块内保留,供 Agent 阅读)。 + UserVerbatimSectionHeading = "## 用户历史输入(原文保留,勿省略或改写)" + + // UserVerbatimSectionStartMarker / EndMarker:HTML 注释边界,供程序化替换;对模型无指令语义。 + UserVerbatimSectionStartMarker = "" + UserVerbatimSectionEndMarker = "" +) + +// ExtractUserContentsFromMessages 按时间顺序提取 user 角色消息的原文(跳过空白)。 +func ExtractUserContentsFromMessages(msgs []database.Message) []string { + out := make([]string, 0, len(msgs)) + for i := range msgs { + if !strings.EqualFold(strings.TrimSpace(msgs[i].Role), "user") { + continue + } + content := strings.TrimSpace(msgs[i].Content) + if content == "" { + continue + } + out = append(out, content) + } + return out +} + +// BuildUserVerbatimAnchorBlockFromMessages 从 messages 表行构建用户原文锚点块。 +// maxRunes: 0 = 不截断;>0 = 总 rune 上限(仍保留每一轮,仅对超长单条做尾部截断提示)。 +func BuildUserVerbatimAnchorBlockFromMessages(msgs []database.Message, maxRunes int) string { + return BuildUserVerbatimAnchorBlock(ExtractUserContentsFromMessages(msgs), maxRunes) +} + +// BuildUserVerbatimAnchorBlock 将各轮用户原文格式化为 system prompt 锚点块。 +func BuildUserVerbatimAnchorBlock(userContents []string, maxRunes int) string { + if len(userContents) == 0 { + return "" + } + lines := make([]string, 0, len(userContents)) + for _, content := range userContents { + content = strings.TrimSpace(content) + if content == "" { + continue + } + lines = append(lines, fmt.Sprintf("[第%d轮] %s", len(lines)+1, content)) + } + if len(lines) == 0 { + return "" + } + body := strings.Join(lines, "\n") + if maxRunes > 0 { + body = capUserVerbatimBody(body, maxRunes) + } + return wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n" + body) +} + +func capUserVerbatimBody(body string, maxRunes int) string { + rs := []rune(body) + if len(rs) <= maxRunes { + return body + } + suffix := "\n\n...(用户原文锚点已达配置上限,更早轮次可能被截断;完整原文见 messages 表)..." + suffixRunes := []rune(suffix) + keep := maxRunes - len(suffixRunes) + if keep <= 0 { + return string(rs[:maxRunes]) + } + return string(rs[:keep]) + suffix +} + +func wrapUserVerbatimBlock(content string) string { + content = strings.TrimSpace(content) + if content == "" { + return "" + } + return UserVerbatimSectionStartMarker + "\n" + content + "\n" + UserVerbatimSectionEndMarker + "\n" +} + +// ReplaceUserVerbatimAnchorSection 用 freshBlock 替换 content 中已有的用户原文锚点段。 +func ReplaceUserVerbatimAnchorSection(content, freshBlock string) (string, bool) { + content = strings.TrimSpace(content) + freshBlock = strings.TrimSpace(freshBlock) + if freshBlock == "" { + return content, false + } + start, ok := userVerbatimSectionStart(content) + if !ok { + return content, false + } + end, ok := userVerbatimSectionEnd(content, start) + if !ok { + return content, false + } + return strings.TrimSpace(content[:start] + freshBlock + content[end:]), true +} + +func userVerbatimSectionStart(content string) (int, bool) { + idx := strings.Index(content, UserVerbatimSectionStartMarker) + if idx < 0 { + return 0, false + } + return idx, true +} + +func userVerbatimSectionEnd(content string, start int) (int, bool) { + if start < 0 || start >= len(content) { + return 0, false + } + tail := content[start:] + idx := strings.LastIndex(tail, UserVerbatimSectionEndMarker) + if idx < 0 { + return 0, false + } + return start + idx + len(UserVerbatimSectionEndMarker), true +} + +// RefreshUserVerbatimAnchorInMessages 在 summarization 等压缩后,用 freshBlock 刷新 system 中的用户原文锚点。 +// 若尚无锚点段,则追加到首条 system 消息;若无 system 消息则在开头插入一条。 +func RefreshUserVerbatimAnchorInMessages(msgs []adk.Message, freshBlock string) []adk.Message { + freshBlock = strings.TrimSpace(freshBlock) + if freshBlock == "" || len(msgs) == 0 { + return msgs + } + + out := make([]adk.Message, len(msgs)) + changed := false + for i, msg := range msgs { + if msg == nil || msg.Role != schema.System { + out[i] = msg + continue + } + newContent, ok := ReplaceUserVerbatimAnchorSection(msg.Content, freshBlock) + if !ok { + out[i] = msg + continue + } + cloned := *msg + cloned.Content = newContent + out[i] = &cloned + changed = true + } + + if changed { + return out + } + + for i, msg := range msgs { + if msg == nil || msg.Role != schema.System { + continue + } + cloned := *msg + cloned.Content = AppendSystemPromptBlock(cloned.Content, freshBlock) + out[i] = &cloned + return out + } + + prefix := make([]adk.Message, 0, len(msgs)+1) + prefix = append(prefix, schema.SystemMessage(freshBlock)) + return append(prefix, msgs...) +} diff --git a/internal/project/user_verbatim_anchor_test.go b/internal/project/user_verbatim_anchor_test.go new file mode 100644 index 00000000..2ace353f --- /dev/null +++ b/internal/project/user_verbatim_anchor_test.go @@ -0,0 +1,96 @@ +package project + +import ( + "strings" + "testing" + + "cyberstrike-ai/internal/database" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" +) + +func TestBuildUserVerbatimAnchorBlock_MultiTurn(t *testing.T) { + msgs := []database.Message{ + {Role: "user", Content: "目标 https://a.com 仅测 /api"}, + {Role: "assistant", Content: "好的"}, + {Role: "user", Content: "用 admin:test 登录"}, + } + block := BuildUserVerbatimAnchorBlockFromMessages(msgs, 0) + if block == "" { + t.Fatal("expected non-empty block") + } + if !strings.Contains(block, UserVerbatimSectionStartMarker) { + t.Error("missing start marker") + } + if !strings.Contains(block, "[第1轮]") || !strings.Contains(block, "https://a.com") { + t.Error("missing first user turn") + } + if !strings.Contains(block, "[第2轮]") || !strings.Contains(block, "admin:test") { + t.Error("missing second user turn") + } + if strings.Contains(block, "好的") { + t.Error("assistant content should not appear") + } +} + +func TestReplaceUserVerbatimAnchorSection(t *testing.T) { + old := "prefix\n\n" + wrapUserVerbatimBlock("## old\n\n[第1轮] a") + "\nsuffix" + newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] b\n[第2轮] c") + out, ok := ReplaceUserVerbatimAnchorSection(old, newBlock) + if !ok { + t.Fatal("expected replace ok") + } + if !strings.Contains(out, "[第2轮] c") { + t.Errorf("expected new block, got %q", out) + } + if !strings.HasPrefix(strings.TrimSpace(out), "prefix") { + t.Error("prefix should remain") + } + if !strings.Contains(out, "suffix") { + t.Error("suffix should remain") + } +} + +func TestRefreshUserVerbatimAnchorInMessages_ReplaceExisting(t *testing.T) { + oldBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] old") + msgs := []adk.Message{ + schema.SystemMessage("instr\n\n" + oldBlock), + schema.UserMessage("hi"), + } + newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] new") + out := RefreshUserVerbatimAnchorInMessages(msgs, newBlock) + if len(out) != 2 { + t.Fatalf("message count: got %d", len(out)) + } + if !strings.Contains(out[0].Content, "[第1轮] new") { + t.Errorf("system content: %q", out[0].Content) + } + if strings.Contains(out[0].Content, "[第1轮] old") { + t.Error("old anchor should be replaced") + } +} + +func TestRefreshUserVerbatimAnchorInMessages_InsertWhenMissing(t *testing.T) { + msgs := []adk.Message{ + schema.SystemMessage("base instruction"), + schema.UserMessage("hi"), + } + block := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] anchor") + out := RefreshUserVerbatimAnchorInMessages(msgs, block) + if !strings.Contains(out[0].Content, "[第1轮] anchor") { + t.Errorf("expected appended anchor, got %q", out[0].Content) + } +} + +func TestBuildUserVerbatimAnchorBlock_MaxRunes(t *testing.T) { + long := strings.Repeat("字", 200) + block := BuildUserVerbatimAnchorBlock([]string{long}, 50) + body := block + if idx := strings.Index(body, UserVerbatimSectionStartMarker); idx >= 0 { + body = strings.TrimPrefix(body[idx+len(UserVerbatimSectionStartMarker):], "\n") + } + if len([]rune(body)) > 120 { + t.Errorf("expected capped body, got %d runes", len([]rune(body))) + } +}