mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-20 23:04:45 +02:00
249 lines
5.9 KiB
Go
249 lines
5.9 KiB
Go
package attackchain
|
|
|
|
import (
|
|
"strings"
|
|
"unicode/utf8"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Fixed parts of the attack-chain prompt budget.
const (
	// attackChainSafetyReserve is a token margin that
	// attackChainPayloadTokenBudget subtracts from the total on top of the
	// other reserves.
	attackChainSafetyReserve = 2048

	// attackChainSystemReserve is a token allowance that
	// attackChainPayloadTokenBudget subtracts from the total
	// (presumably for the system message; confirm against the caller).
	attackChainSystemReserve = 256

	// attackChainTruncationMarker is placed between the kept head and tail
	// when truncateTextByTokens cuts the middle out of a text.
	attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
)
|
|
|
|
// attackChainMaxCompletionTokens returns the completion-token ceiling reserved
// for the attack-chain JSON output.
//
// A non-positive maxTotal yields 8192. Otherwise the result is one eighth of
// maxTotal, clamped to the range [4096, 16384].
func attackChainMaxCompletionTokens(maxTotal int) int {
	const (
		fallbackTokens = 8192
		floorTokens    = 4096
		capTokens      = 16384
	)
	if maxTotal <= 0 {
		return fallbackTokens
	}
	switch share := maxTotal / 8; {
	case share < floorTokens:
		return floorTokens
	case share > capTokens:
		return capTokens
	default:
		return share
	}
}
|
|
|
|
func (b *Builder) modelName() string {
|
|
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
|
|
return b.openAIConfig.Model
|
|
}
|
|
return "gpt-4"
|
|
}
|
|
|
|
func (b *Builder) countTokens(text string) int {
|
|
if text == "" {
|
|
return 0
|
|
}
|
|
n, err := b.tokenCounter.Count(b.modelName(), text)
|
|
if err != nil {
|
|
return utf8.RuneCountInString(text) / 4
|
|
}
|
|
return n
|
|
}
|
|
|
|
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
|
|
func (b *Builder) attackChainPayloadTokenBudget() int {
|
|
maxTotal := b.maxTokens
|
|
if maxTotal <= 0 {
|
|
maxTotal = 100000
|
|
}
|
|
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
|
|
completion := attackChainMaxCompletionTokens(maxTotal)
|
|
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
|
|
budget := maxTotal - reserve
|
|
minBudget := maxTotal * 35 / 100
|
|
if budget < minBudget {
|
|
budget = minBudget
|
|
}
|
|
if budget < 4096 {
|
|
budget = 4096
|
|
}
|
|
return budget
|
|
}
|
|
|
|
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
|
|
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
|
|
budget := b.attackChainPayloadTokenBudget()
|
|
modelBudget := budget * 15 / 100
|
|
if modelBudget < 512 {
|
|
modelBudget = 512
|
|
}
|
|
reactBudget := budget - modelBudget
|
|
|
|
origReactTok := b.countTokens(reactInput)
|
|
origModelTok := b.countTokens(modelOutput)
|
|
truncated := false
|
|
|
|
outModel := modelOutput
|
|
if origModelTok > modelBudget {
|
|
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
|
|
truncated = true
|
|
}
|
|
|
|
outReact := reactInput
|
|
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
|
|
for _, lim := range perToolLimits {
|
|
compact := compactFormattedToolBodies(outReact, lim)
|
|
if compact != outReact {
|
|
outReact = compact
|
|
truncated = true
|
|
}
|
|
if b.countTokens(outReact) <= reactBudget {
|
|
break
|
|
}
|
|
}
|
|
|
|
if b.countTokens(outReact) > reactBudget {
|
|
outReact = truncateTextByTokens(b, outReact, reactBudget)
|
|
truncated = true
|
|
}
|
|
|
|
if truncated {
|
|
b.logger.Info("攻击链输入已按 token 预算截断",
|
|
zap.Int("maxTotalTokens", b.maxTokens),
|
|
zap.Int("payloadBudget", budget),
|
|
zap.Int("reactBudget", reactBudget),
|
|
zap.Int("modelBudget", modelBudget),
|
|
zap.Int("reactInputTokensBefore", origReactTok),
|
|
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
|
|
zap.Int("modelOutputTokensBefore", origModelTok),
|
|
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
|
|
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
|
|
)
|
|
}
|
|
|
|
return outReact, outModel, truncated
|
|
}
|
|
|
|
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
|
|
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
|
|
if maxRunesPerBody <= 0 || s == "" {
|
|
return s
|
|
}
|
|
const marker = "[tool]"
|
|
var out strings.Builder
|
|
remaining := s
|
|
changed := false
|
|
for {
|
|
idx := strings.Index(remaining, marker)
|
|
if idx < 0 {
|
|
out.WriteString(remaining)
|
|
break
|
|
}
|
|
out.WriteString(remaining[:idx])
|
|
remaining = remaining[idx:]
|
|
nl := strings.IndexByte(remaining, '\n')
|
|
if nl < 0 {
|
|
out.WriteString(remaining)
|
|
break
|
|
}
|
|
header := remaining[:nl+1]
|
|
remaining = remaining[nl+1:]
|
|
bodyEnd := strings.Index(remaining, "\n\n[")
|
|
var body, rest string
|
|
if bodyEnd < 0 {
|
|
body = remaining
|
|
rest = ""
|
|
} else {
|
|
body = remaining[:bodyEnd]
|
|
rest = remaining[bodyEnd:]
|
|
}
|
|
if runeLen(body) > maxRunesPerBody {
|
|
body = truncateRunesWithNotice(body, maxRunesPerBody)
|
|
changed = true
|
|
}
|
|
out.WriteString(header)
|
|
out.WriteString(body)
|
|
remaining = rest
|
|
if rest == "" {
|
|
break
|
|
}
|
|
}
|
|
if !changed {
|
|
return s
|
|
}
|
|
return out.String()
|
|
}
|
|
|
|
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
|
|
if maxTokens <= 0 || text == "" {
|
|
return ""
|
|
}
|
|
if b.countTokens(text) <= maxTokens {
|
|
return text
|
|
}
|
|
markerTok := b.countTokens(attackChainTruncationMarker)
|
|
usable := maxTokens - markerTok
|
|
if usable < 256 {
|
|
usable = maxTokens / 2
|
|
}
|
|
headBudget := usable * 60 / 100
|
|
tailBudget := usable - headBudget
|
|
head := takeTokensFromStart(b, text, headBudget)
|
|
tail := takeTokensFromEnd(b, text, tailBudget)
|
|
return head + attackChainTruncationMarker + tail
|
|
}
|
|
|
|
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
|
|
rs := []rune(text)
|
|
if len(rs) == 0 || maxTokens <= 0 {
|
|
return ""
|
|
}
|
|
lo, hi := 0, len(rs)
|
|
for lo < hi {
|
|
mid := (lo + hi + 1) / 2
|
|
if b.countTokens(string(rs[:mid])) <= maxTokens {
|
|
lo = mid
|
|
} else {
|
|
hi = mid - 1
|
|
}
|
|
}
|
|
return string(rs[:lo])
|
|
}
|
|
|
|
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
|
|
rs := []rune(text)
|
|
if len(rs) == 0 || maxTokens <= 0 {
|
|
return ""
|
|
}
|
|
lo, hi := 0, len(rs)
|
|
for lo < hi {
|
|
mid := (lo + hi) / 2
|
|
if b.countTokens(string(rs[mid:])) <= maxTokens {
|
|
hi = mid
|
|
} else {
|
|
lo = mid + 1
|
|
}
|
|
}
|
|
return string(rs[lo:])
|
|
}
|
|
|
|
// truncateRunesWithNotice cuts the middle out of s when it is longer than
// maxRunes runes, replacing it with a bilingual "tool output truncated"
// notice.
//
// Strings of at most maxRunes runes are returned unchanged. Otherwise the
// number of kept runes is maxRunes minus the notice length, or two thirds of
// maxRunes when that would be below 200; 70% of them come from the start of
// s and the rest from its end. If no rune can be kept, only the notice is
// returned. For small maxRunes the result can be longer than maxRunes.
func truncateRunesWithNotice(s string, maxRunes int) string {
	const notice = "\n...[工具输出已截断 / tool output truncated]...\n"

	runes := []rune(s)
	total := len(runes)
	if total <= maxRunes {
		return s
	}

	budget := maxRunes - len([]rune(notice))
	if budget < 200 {
		// Too little room left after the notice: keep two thirds of the
		// limit rather than almost nothing.
		budget = maxRunes * 2 / 3
	}
	if budget < 1 {
		return notice
	}

	headLen := budget * 70 / 100
	tailLen := budget - headLen
	return string(runes[:headLen]) + notice + string(runes[total-tailLen:])
}
|
|
|
|
// runeLen returns the number of runes in s. Each invalid UTF-8 byte counts as
// one rune.
func runeLen(s string) int {
	count := 0
	for range s {
		count++
	}
	return count
}
|