Files
CyberStrikeAI/internal/agent/token_counter.go
T
2026-06-02 13:25:00 +08:00

55 lines
1.1 KiB
Go

package agent
import (
"sync"
"github.com/pkoukk/tiktoken-go"
)
// TokenCounter estimates how many tokens a piece of text occupies.
//
// Implementations are tiktoken-based; when the model is unknown the
// cl100k_base encoding is used as the fallback.
type TokenCounter interface {
	// Count reports the number of tokens in text under the encoding chosen
	// for model, or an error when no encoding could be obtained.
	Count(model string, text string) (int, error)
}
// tikTokenCounter is the tiktoken-backed TokenCounter returned by
// NewTikTokenCounter. It keeps the encoders it has already obtained so that
// later lookups for the same name are served from memory.
type tikTokenCounter struct {
	// mu guards cache; it is taken for the whole lookup-or-create sequence.
	mu sync.Mutex
	// cache holds the encoders obtained so far, keyed by the name they were
	// requested under.
	cache map[string]*tiktoken.Tiktoken
}
// NewTikTokenCounter builds a TokenCounter backed by tiktoken. The returned
// counter starts with an empty encoder cache that is filled on demand.
func NewTikTokenCounter() TokenCounter {
	counter := new(tikTokenCounter)
	counter.cache = map[string]*tiktoken.Tiktoken{}
	return counter
}
// encoding returns the tiktoken encoder to use for model, obtaining and
// caching it on first use.
//
// An empty model selects cl100k_base directly. A model name that
// tiktoken.EncodingForModel rejects also resolves to cl100k_base; the single
// cached cl100k_base encoder is then aliased under that model name, so unknown
// names do not each trigger another tiktoken.GetEncoding call. An error is
// returned only when the cl100k_base encoding itself cannot be obtained.
func (c *tikTokenCounter) encoding(model string) (*tiktoken.Tiktoken, error) {
	const fallback = "cl100k_base"

	key := model
	if key == "" {
		key = fallback
	}

	// The lock is held across the whole lookup-or-create sequence so that
	// concurrent callers never obtain the same encoder twice.
	c.mu.Lock()
	defer c.mu.Unlock()

	if enc, ok := c.cache[key]; ok {
		return enc, nil
	}

	// fallback is passed to GetEncoding as an encoding name, so it is never
	// sent through the model-name lookup.
	if key != fallback {
		if enc, err := tiktoken.EncodingForModel(key); err == nil {
			c.cache[key] = enc
			return enc, nil
		}
		// Unknown model: reuse the fallback encoder if one is already cached.
		if enc, ok := c.cache[fallback]; ok {
			c.cache[key] = enc
			return enc, nil
		}
	}

	enc, err := tiktoken.GetEncoding(fallback)
	if err != nil {
		return nil, err
	}
	c.cache[fallback] = enc
	c.cache[key] = enc
	return enc, nil
}
// Count reports how many tokens text encodes to under the encoder selected
// for model. Empty text yields zero without looking up an encoder; a failure
// to obtain the encoder is returned to the caller unchanged.
func (c *tikTokenCounter) Count(model, text string) (int, error) {
	if len(text) == 0 {
		return 0, nil
	}

	encoder, err := c.encoding(model)
	if err != nil {
		return 0, err
	}

	// No special tokens are allowed or disallowed for the encode call.
	tokens := encoder.Encode(text, nil, nil)
	return len(tokens), nil
}