Add files via upload

This commit is contained in:
公明
2026-06-03 17:16:48 +08:00
committed by GitHub
parent da8fdafe59
commit e606369e31
11 changed files with 1041 additions and 2 deletions
+1 -1
View File
@@ -491,7 +491,7 @@ func appendAttachmentsToMessage(msg string, attachments []ChatAttachment, savedP
}
var b strings.Builder
b.WriteString(msg)
b.WriteString("\n\n[用户上传的文件已保存到以下路径(请按需读取文件内容,而不是依赖内联内容)]\n")
b.WriteString("\n\n[用户上传的文件]\n")
for i, a := range attachments {
if i < len(savedPaths) && savedPaths[i] != "" {
b.WriteString(fmt.Sprintf("- %s: %s\n", a.FileName, savedPaths[i]))
+147
View File
@@ -237,6 +237,7 @@ func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) err
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
Vision config.VisionConfig `json:"vision"`
FOFA config.FofaConfig `json:"fofa"`
MCP config.MCPConfig `json:"mcp"`
Tools []ToolConfigInfo `json:"tools"`
@@ -333,6 +334,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
c.JSON(http.StatusOK, GetConfigResponse{
OpenAI: h.config.OpenAI,
Vision: h.config.Vision,
FOFA: h.config.FOFA,
MCP: h.config.MCP,
Tools: tools,
@@ -638,6 +640,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// UpdateConfigRequest 更新配置请求
type UpdateConfigRequest struct {
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
Vision *config.VisionConfig `json:"vision,omitempty"`
FOFA *config.FofaConfig `json:"fofa,omitempty"`
MCP *config.MCPConfig `json:"mcp,omitempty"`
Tools []ToolEnableStatus `json:"tools,omitempty"`
@@ -707,6 +710,14 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
)
}
if req.Vision != nil {
h.config.Vision = *req.Vision
h.logger.Info("更新 Vision 配置",
zap.Bool("enabled", h.config.Vision.Enabled),
zap.String("model", h.config.Vision.Model),
)
}
// 更新FOFA配置
if req.FOFA != nil {
h.config.FOFA = *req.FOFA
@@ -1031,6 +1042,99 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
})
}
// TestVisionRequest is the JSON body accepted by the vision connectivity test.
// When vision.api_key / vision.base_url are left empty, the optional openai
// section may be supplied as a fallback.
type TestVisionRequest struct {
	Vision config.VisionConfig `json:"vision"`
	OpenAI config.OpenAIConfig `json:"openai,omitempty"`
}
// TestVision probes the vision model API by issuing a minimal chat completion.
//
// The endpoint settings come from req.Vision.OpenAICfgEffective(req.OpenAI).
// Invalid requests (malformed JSON, empty API key, empty model) are answered
// with HTTP 400. Upstream failures are answered with HTTP 200 and
// "success": false so the caller can display the reason. On success the
// response carries the model name reported by the API and the round-trip
// latency in milliseconds.
func (h *ConfigHandler) TestVision(c *gin.Context) {
	var req TestVisionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
		return
	}

	oa := req.Vision.OpenAICfgEffective(req.OpenAI)
	if strings.TrimSpace(oa.APIKey) == "" {
		// The closing parenthesis was missing from this user-facing message.
		c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空(可填写 vision.api_key 或 openai.api_key)"})
		return
	}
	if strings.TrimSpace(oa.Model) == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "视觉模型不能为空"})
		return
	}

	// Fall back to the provider's public endpoint when no base URL is configured.
	baseURL := strings.TrimSuffix(strings.TrimSpace(oa.BaseURL), "/")
	if baseURL == "" {
		if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
			baseURL = "https://api.anthropic.com"
		} else {
			baseURL = "https://api.openai.com/v1"
		}
	}

	// Smallest request that still exercises authentication and model lookup.
	payload := map[string]interface{}{
		"model": oa.Model,
		"messages": []map[string]string{
			{"role": "user", "content": "Hi"},
		},
		"max_completion_tokens": 5,
	}
	tmpCfg := &config.OpenAIConfig{
		Provider: oa.Provider,
		BaseURL:  baseURL,
		APIKey:   strings.TrimSpace(oa.APIKey),
		Model:    oa.Model,
	}
	client := openai.NewClient(tmpCfg, nil, h.logger)

	ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
	defer cancel()

	start := time.Now()
	var chatResp struct {
		Model   string `json:"model"`
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	err := client.ChatCompletion(ctx, payload, &chatResp)
	latency := time.Since(start)

	if err != nil {
		if apiErr, ok := err.(*openai.APIError); ok {
			c.JSON(http.StatusOK, gin.H{
				"success":     false,
				"error":       fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body),
				"status_code": apiErr.StatusCode,
			})
			return
		}
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"error":   "连接失败: " + err.Error(),
		})
		return
	}

	// A 2xx answer without choices is not a usable chat completion response.
	if len(chatResp.Choices) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"error":   "API 响应缺少 choices 字段,请检查 Base URL 与视觉模型名称",
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success":    true,
		"model":      chatResp.Model,
		"latency_ms": latency.Milliseconds(),
	})
}
// ApplyConfig 应用配置(重新加载并重启相关服务)
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
@@ -1286,6 +1390,7 @@ func (h *ConfigHandler) saveConfig() error {
updateAgentConfig(root, h.config.Agent)
updateMCPConfig(root, h.config.MCP)
updateOpenAIConfig(root, h.config.OpenAI)
updateVisionConfig(root, h.config.Vision)
updateFOFAConfig(root, h.config.FOFA)
updateKnowledgeConfig(root, h.config.Knowledge)
updateC2Config(root, h.config.C2)
@@ -1406,6 +1511,48 @@ func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
setIntInMap(mcpNode, "port", cfg.Port)
}
// updateVisionConfig writes cfg into the "vision" mapping under the first
// content node of the YAML document.
//
// enabled, api_key, base_url, model and skip_preprocess_below_bytes are always
// written; api_key and base_url are stored as "" when they contain only
// whitespace. Every other key is written only when it carries a non-empty /
// positive value.
//
// NOTE(review): because the optional keys are skipped when empty, clearing one
// of them (for example allowed_roots) presumably leaves the previous value in
// the file — confirm against the map helpers.
func updateVisionConfig(doc *yaml.Node, cfg config.VisionConfig) {
	visionNode := ensureMap(doc.Content[0], "vision")

	isBlank := func(v string) bool { return strings.TrimSpace(v) == "" }
	// blankToEmpty keeps v untouched unless it is whitespace-only.
	blankToEmpty := func(v string) string {
		if isBlank(v) {
			return ""
		}
		return v
	}

	setBoolInMap(visionNode, "enabled", cfg.Enabled)
	setStringInMap(visionNode, "api_key", blankToEmpty(cfg.APIKey))
	setStringInMap(visionNode, "base_url", blankToEmpty(cfg.BaseURL))
	setStringInMap(visionNode, "model", cfg.Model)
	if !isBlank(cfg.Provider) {
		setStringInMap(visionNode, "provider", cfg.Provider)
	}

	if cfg.TimeoutSeconds > 0 {
		setIntInMap(visionNode, "timeout_seconds", cfg.TimeoutSeconds)
	}
	if cfg.MaxImageBytes > 0 {
		setIntInMap(visionNode, "max_image_bytes", int(cfg.MaxImageBytes))
	}
	if cfg.MaxDimension > 0 {
		setIntInMap(visionNode, "max_dimension", cfg.MaxDimension)
	}
	if cfg.JPEGQuality > 0 {
		setIntInMap(visionNode, "jpeg_quality", cfg.JPEGQuality)
	}
	if cfg.MaxPayloadBytes > 0 {
		setIntInMap(visionNode, "max_payload_bytes", int(cfg.MaxPayloadBytes))
	}
	setIntInMap(visionNode, "skip_preprocess_below_bytes", int(cfg.SkipPreprocessBelowBytes))

	if !isBlank(cfg.Detail) {
		setStringInMap(visionNode, "detail", cfg.Detail)
	}
	if len(cfg.AllowedRoots) > 0 {
		setStringSliceInMap(visionNode, "allowed_roots", cfg.AllowedRoots)
	}
}
func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
root := doc.Content[0]
openaiNode := ensureMap(root, "openai")
+91 -1
View File
@@ -778,11 +778,55 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
},
"ConfigResponse": map[string]interface{}{
"type": "object",
"description": "配置信息",
"description": "配置信息(含 openai、vision、multi_agent 等)",
"properties": map[string]interface{}{
"vision": map[string]interface{}{
"$ref": "#/components/schemas/VisionConfig",
},
},
},
"UpdateConfigRequest": map[string]interface{}{
"type": "object",
"description": "更新配置请求",
"properties": map[string]interface{}{
"vision": map[string]interface{}{
"$ref": "#/components/schemas/VisionConfig",
},
},
},
"VisionConfig": map[string]interface{}{
"type": "object",
"description": "视觉分析(analyze_image MCP 工具);enabled 且 model 非空时注册工具",
"properties": map[string]interface{}{
"enabled": map[string]interface{}{"type": "boolean", "description": "是否启用 analyze_image"},
"model": map[string]interface{}{"type": "string", "description": "视觉模型名(必填)", "example": "qwen-vl-max"},
"api_key": map[string]interface{}{"type": "string", "description": "API Key;留空复用 openai.api_key"},
"base_url": map[string]interface{}{"type": "string", "description": "Base URL;留空复用 openai.base_url"},
"provider": map[string]interface{}{"type": "string", "description": "提供商;留空复用 openai.provider"},
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "VL 调用超时(秒)"},
"max_image_bytes": map[string]interface{}{"type": "integer", "description": "原始文件大小上限(字节)"},
"max_dimension": map[string]interface{}{"type": "integer", "description": "长边缩放像素"},
"jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"},
"max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"},
"skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"},
"detail": map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"},
"allowed_roots": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "额外允许读取的绝对路径根"},
},
},
"AnalyzeImageToolCall": map[string]interface{}{
"type": "object",
"description": "内置 MCP 工具 analyze_image:分析服务器本地图片,返回纯文本(验证码/UI/报错等)",
"properties": map[string]interface{}{
"path": map[string]interface{}{
"type": "string",
"description": "图片路径(cwd、chat_uploads、result_storage_dir 或 allowed_roots 下)",
},
"question": map[string]interface{}{
"type": "string",
"description": "可选:重点问题;验证码建议「只输出验证码字符」",
},
},
"required": []string{"path"},
},
"ExternalMCPConfig": map[string]interface{}{
"type": "object",
@@ -4900,6 +4944,52 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
},
// ==================== 配置管理 - 缺失端点 ====================
"/api/config/test-vision": map[string]interface{}{
"post": map[string]interface{}{
"tags": []string{"配置管理"},
"summary": "测试视觉模型连接",
"description": "测试 Vision 模型 API 是否可用。vision.api_key/base_url 留空时可传 openai 段作回退。",
"operationId": "testVision",
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"required": []string{"vision"},
"properties": map[string]interface{}{
"vision": map[string]interface{}{"$ref": "#/components/schemas/VisionConfig"},
"openai": map[string]interface{}{
"type": "object",
"description": "主 LLM 配置(vision 字段留空时用于 API Key/Base URL 回退)",
},
},
},
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "测试结果",
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"success": map[string]interface{}{"type": "boolean"},
"error": map[string]interface{}{"type": "string"},
"model": map[string]interface{}{"type": "string"},
"latency_ms": map[string]interface{}{"type": "number"},
},
},
},
},
},
"400": map[string]interface{}{"description": "参数错误"},
"401": map[string]interface{}{"description": "未授权"},
},
},
},
"/api/config/test-openai": map[string]interface{}{
"post": map[string]interface{}{
"tags": []string{"配置管理"},
+22
View File
@@ -0,0 +1,22 @@
package project
import "strings"
// VisionImageAnalysisSection returns the "图片分析" prompt section shared by the
// single-agent and multi-agent system prompts. It tells the model to use the
// analyze_image tool for image files instead of reading them as text.
func VisionImageAnalysisSection() string {
	var b strings.Builder
	b.WriteString("## 图片分析\n\n")
	b.WriteString("- 遇到图片文件(截图、验证码、登录页、报告配图)时,若存在工具 analyze_image,请传入服务器上的文件路径进行分析。\n")
	b.WriteString("- 不要对二进制图片使用 read_file 指望理解内容;用户消息中「📎 xxx.png: /path」即为可传给 analyze_image 的路径。\n")
	// The tool name and its "question" argument must stay separate tokens:
	// fused together they read as a tool that does not exist.
	b.WriteString("- 验证码类:若已从页面或接口保存为本地图片(如 captcha.png),用 analyze_image,question 写明「只输出验证码字符」;识别失败则刷新验证码后重新保存再识;复杂滑块/行为验证码勿指望单次识图成功。\n")
	b.WriteString("- 委派子代理时,若子任务含验证码/截图识读,在 task description 中写明图片路径与期望输出格式。\n")
	return b.String()
}
// AppendVisionImageAnalysisIfReady 仅在 vision.enabled 且 model 已配置时追加图片分析提示。
func AppendVisionImageAnalysisIfReady(base string, visionReady bool) string {
if !visionReady {
return base
}
return AppendSystemPromptBlock(base, VisionImageAnalysisSection())
}
+132
View File
@@ -0,0 +1,132 @@
package vision
import (
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/openai"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/schema"
)
// Client calls a dedicated vision chat model, one Generate call per analysis.
type Client struct {
	// cfg holds the vision-specific settings.
	cfg config.VisionConfig
	// mainOA is the main OpenAI configuration passed to cfg.OpenAICfgEffective.
	mainOA config.OpenAIConfig
}
// NewClient builds a vision Client from the vision settings and the main
// OpenAI configuration.
func NewClient(visionCfg config.VisionConfig, mainOpenAI config.OpenAIConfig) *Client {
	client := new(Client)
	client.cfg = visionCfg
	client.mainOA = mainOpenAI
	return client
}
// Analyze sends the image bytes to the vision model and returns the model's
// answer with surrounding whitespace removed.
//
// An empty img.MIMEType is treated as "image/jpeg". question is passed through
// buildVisionPrompt. The call is bounded by the configured vision timeout.
//
// Errors are returned when the payload is empty, when the effective API key or
// model is empty, when the chat model cannot be constructed, when generation
// fails, or when the model answers with empty content.
func (c *Client) Analyze(ctx context.Context, img ImagePayload, question string) (string, error) {
	if len(img.Bytes) == 0 {
		return "", fmt.Errorf("empty image payload")
	}
	mime := strings.TrimSpace(img.MIMEType)
	if mime == "" {
		mime = "image/jpeg"
	}

	oa := c.cfg.OpenAICfgEffective(c.mainOA)
	if strings.TrimSpace(oa.APIKey) == "" {
		return "", fmt.Errorf("vision API key is empty (set vision.api_key or openai.api_key)")
	}
	if strings.TrimSpace(oa.Model) == "" {
		return "", fmt.Errorf("vision model is empty")
	}

	timeout := time.Duration(c.cfg.TimeoutSecondsEffective()) * time.Second
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	transport := &http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   60 * time.Second,
			KeepAlive: 60 * time.Second,
		}).DialContext,
		ResponseHeaderTimeout: timeout + 10*time.Second,
	}
	// A fresh Transport is created for every call and has no idle timeout, so
	// its keep-alive connections would otherwise stay open until the peer
	// closes them. Release them as soon as this call is done.
	defer transport.CloseIdleConnections()

	httpClient := &http.Client{
		// Slightly above the context deadline so the context error wins.
		Timeout:   timeout + 15*time.Second,
		Transport: transport,
	}
	httpClient = openai.NewEinoHTTPClient(&oa, httpClient)

	modelCfg := &einoopenai.ChatModelConfig{
		APIKey:     oa.APIKey,
		BaseURL:    strings.TrimSuffix(oa.BaseURL, "/"),
		Model:      oa.Model,
		HTTPClient: httpClient,
	}
	chatModel, err := einoopenai.NewChatModel(ctx, modelCfg)
	if err != nil {
		return "", fmt.Errorf("vision chat model: %w", err)
	}

	b64 := base64.StdEncoding.EncodeToString(img.Bytes)

	// Anything other than "high" or "auto" falls back to low detail.
	detail := schema.ImageURLDetailLow
	switch c.cfg.DetailEffective() {
	case "high":
		detail = schema.ImageURLDetailHigh
	case "auto":
		detail = schema.ImageURLDetailAuto
	}

	prompt := buildVisionPrompt(question)
	userMsg := &schema.Message{
		Role: schema.User,
		UserInputMultiContent: []schema.MessageInputPart{
			{Type: schema.ChatMessagePartTypeText, Text: prompt},
			{
				Type: schema.ChatMessagePartTypeImageURL,
				Image: &schema.MessageInputImage{
					MessagePartCommon: schema.MessagePartCommon{
						Base64Data: &b64,
						MIMEType:   mime,
					},
					Detail: detail,
				},
			},
		},
	}

	resp, err := chatModel.Generate(ctx, []*schema.Message{userMsg})
	if err != nil {
		return "", fmt.Errorf("vision generate: %w", err)
	}
	if resp == nil || strings.TrimSpace(resp.Content) == "" {
		return "", fmt.Errorf("vision model returned empty content")
	}
	return strings.TrimSpace(resp.Content), nil
}
// buildVisionPrompt assembles the text part of the vision request.
//
// A blank question is replaced by a generic description request. When the
// question looks like a captcha request, an extra instruction asking for the
// bare character sequence is appended.
func buildVisionPrompt(question string) string {
	q := strings.TrimSpace(question)
	if q == "" {
		q = "请对图片做通用描述,侧重授权安全测试场景(可见文本、表单、按钮、验证码、错误信息、技术栈线索)。"
	}
	extra := ""
	if looksLikeCaptchaQuestion(q) {
		extra = "\n若为验证码:仅输出你辨认出的字符序列,不要空格、标点、解释;看不清则明确说无法识别。"
	}
	// The instruction previously had no sentence punctuation and the label ran
	// straight into the question text; punctuate it and separate the label
	// from the question with a colon.
	const header = "你是授权安全测试助手。请根据图片回答用户问题,只描述你能从图中确认的内容,不要编造。\n用户问题:"
	return header + q + extra
}
// looksLikeCaptchaQuestion reports whether q reads like a captcha-recognition
// request. Matching is case-insensitive: any captcha keyword qualifies, and so
// does "只输出" combined with "字符" or "character".
func looksLikeCaptchaQuestion(q string) bool {
	lowered := strings.ToLower(q)
	has := func(sub string) bool { return strings.Contains(lowered, sub) }

	switch {
	case has("验证码"), has("captcha"), has("verification code"),
		has("verify code"), has("vcode"), has("图形码"):
		return true
	case !has("只输出"):
		return false
	default:
		return has("字符") || has("character")
	}
}
+12
View File
@@ -0,0 +1,12 @@
package vision
import "testing"
// TestLooksLikeCaptchaQuestion checks one captcha-style question and one
// ordinary layout question.
func TestLooksLikeCaptchaQuestion(t *testing.T) {
	captchaDetected := looksLikeCaptchaQuestion("识别验证码,只输出字符")
	if !captchaDetected {
		t.Fatal("expected captcha hint")
	}

	layoutDetected := looksLikeCaptchaQuestion("描述登录页布局")
	if layoutDetected {
		t.Fatal("expected non-captcha")
	}
}
+142
View File
@@ -0,0 +1,142 @@
package vision
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// chatUploadsDirName is the directory name that buildAllowedRoots joins onto
// the working directory and always adds to the allowed image roots.
const chatUploadsDirName = "chat_uploads"
// allowedImageExt is the set of lowercase file extensions that
// ResolveImagePath accepts; any other extension is rejected.
var allowedImageExt = map[string]struct{}{
	".png": {}, ".jpg": {}, ".jpeg": {}, ".webp": {}, ".gif": {},
	".bmp": {}, ".tif": {}, ".tiff": {},
}
// PathOptions lists the directories from which image paths may be read.
type PathOptions struct {
	CWD string
	ResultStorageDir string // relative to CWD (e.g. tmp); absolute values are used as given
	ExtraRoots []string // vision.allowed_roots, absolute paths
}
// ResolveImagePath validates path and returns the path to read the image from.
//
// Relative paths are resolved against opt.CWD (or the process working
// directory when opt.CWD is blank). The path is cleaned and symlinks are
// resolved where possible before the checks, so traversal and symlink escapes
// are judged on the real location. The result must carry an allowed image
// extension, lie under one of the allowed roots, and be a regular file no
// larger than 1 GiB.
func ResolveImagePath(path string, opt PathOptions) (string, error) {
	p := strings.TrimSpace(path)
	if p == "" {
		return "", fmt.Errorf("path is empty")
	}

	cwd := strings.TrimSpace(opt.CWD)
	if cwd == "" {
		var err error
		cwd, err = os.Getwd()
		if err != nil {
			return "", fmt.Errorf("getwd: %w", err)
		}
	}
	cwdAbs, err := filepath.Abs(filepath.Clean(cwd))
	if err != nil {
		return "", err
	}

	var candidate string
	if filepath.IsAbs(p) {
		candidate = filepath.Clean(p)
	} else {
		candidate = filepath.Clean(filepath.Join(cwdAbs, p))
	}
	candidate = normalizeAbsPath(candidate)
	if candidate == "" {
		return "", fmt.Errorf("invalid path")
	}

	// Checked after symlink resolution so the extension is that of the target.
	ext := strings.ToLower(filepath.Ext(candidate))
	if _, ok := allowedImageExt[ext]; !ok {
		return "", fmt.Errorf("unsupported image extension %q", ext)
	}

	roots := buildAllowedRoots(cwdAbs, opt)
	resolved, err := evalUnderAllowedRoots(candidate, roots)
	if err != nil {
		return "", err
	}

	st, err := os.Stat(resolved)
	if err != nil {
		return "", fmt.Errorf("stat: %w", err)
	}
	// Reject everything that is not a plain file (directories, FIFOs, devices,
	// sockets); previously only directories were refused despite the message.
	if !st.Mode().IsRegular() {
		return "", fmt.Errorf("not a regular file")
	}
	if st.Size() > 1<<30 {
		return "", fmt.Errorf("file too large on disk")
	}
	return resolved, nil
}
// normalizeAbsPath returns the cleaned absolute form of p with symlinks
// resolved. When symlink resolution fails the unresolved absolute path is
// returned, and "" is returned when p cannot be made absolute.
func normalizeAbsPath(p string) string {
	absolute, absErr := filepath.Abs(filepath.Clean(p))
	if absErr != nil {
		return ""
	}
	resolved, linkErr := filepath.EvalSymlinks(absolute)
	if linkErr != nil {
		return absolute
	}
	return resolved
}
// buildAllowedRoots returns the normalized, de-duplicated directories an image
// may be read from, in this order: the working directory, its chat uploads
// directory, the result storage directory ("tmp" when unset; relative values
// are joined onto cwdAbs) and every entry of opt.ExtraRoots. Blank entries and
// entries that cannot be normalized are dropped.
func buildAllowedRoots(cwdAbs string, opt PathOptions) []string {
	storage := strings.TrimSpace(opt.ResultStorageDir)
	if storage == "" {
		storage = "tmp"
	}
	if !filepath.IsAbs(storage) {
		storage = filepath.Join(cwdAbs, storage)
	}

	candidates := make([]string, 0, 3+len(opt.ExtraRoots))
	candidates = append(candidates, cwdAbs, filepath.Join(cwdAbs, chatUploadsDirName), storage)
	candidates = append(candidates, opt.ExtraRoots...)

	var roots []string
	seen := make(map[string]struct{})
	for _, raw := range candidates {
		raw = strings.TrimSpace(raw)
		if raw == "" {
			continue
		}
		abs := normalizeAbsPath(raw)
		if abs == "" {
			continue
		}
		if _, dup := seen[abs]; dup {
			continue
		}
		seen[abs] = struct{}{}
		roots = append(roots, abs)
	}
	return roots
}
// evalUnderAllowedRoots normalizes candidate (absolute, symlinks resolved
// where possible) and returns that normalized path when it lies under one of
// roots. An error naming the original candidate is returned otherwise.
//
// The normalized path is returned rather than the input, so the caller opens
// exactly the location that passed the check.
func evalUnderAllowedRoots(candidate string, roots []string) (string, error) {
	resolved := normalizeAbsPath(candidate)
	for _, root := range roots {
		if isUnderRoot(resolved, root) {
			return resolved, nil
		}
	}
	return "", fmt.Errorf("path %q is outside allowed directories", candidate)
}
// isUnderRoot reports whether path equals root or lies inside it. Both
// arguments are cleaned first and the comparison is purely lexical, so callers
// must resolve symlinks beforehand.
func isUnderRoot(path, root string) bool {
	path = filepath.Clean(path)
	root = filepath.Clean(root)
	if path == root {
		return true
	}
	sep := string(filepath.Separator)
	// A cleaned filesystem root already ends with the separator; appending a
	// second one would build a prefix such as "//" that matches nothing.
	if !strings.HasSuffix(root, sep) {
		root += sep
	}
	return strings.HasPrefix(path, root)
}
+43
View File
@@ -0,0 +1,43 @@
package vision
import (
"os"
"path/filepath"
"testing"
)
func TestResolveImagePath_underCWD(t *testing.T) {
dir := t.TempDir()
img := filepath.Join(dir, "shot.png")
if err := os.WriteFile(img, []byte{0x89, 0x50, 0x4e, 0x47}, 0o644); err != nil {
t.Fatal(err)
}
got, err := ResolveImagePath(img, PathOptions{CWD: dir, ResultStorageDir: "tmp"})
if err != nil {
t.Fatal(err)
}
want := normalizeAbsPath(img)
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestResolveImagePath_rejectsTraversal(t *testing.T) {
dir := t.TempDir()
_, err := ResolveImagePath("../../../etc/passwd", PathOptions{CWD: dir})
if err == nil {
t.Fatal("expected error for path outside roots")
}
}
// TestResolveImagePath_rejectsNonImageExt expects an existing file with a
// non-image extension to be refused.
func TestResolveImagePath_rejectsNonImageExt(t *testing.T) {
	root := t.TempDir()
	notes := filepath.Join(root, "notes.txt")
	if writeErr := os.WriteFile(notes, []byte("x"), 0o644); writeErr != nil {
		t.Fatal(writeErr)
	}

	if _, err := ResolveImagePath(notes, PathOptions{CWD: root}); err == nil {
		t.Fatal("expected error for non-image extension")
	}
}
+212
View File
@@ -0,0 +1,212 @@
package vision
import (
"bytes"
"fmt"
"image"
"os"
"strings"
"github.com/disintegration/imaging"
)
// ImagePayload carries the image bytes and MIME type sent to the vision API.
type ImagePayload struct {
	Bytes []byte
	MIMEType string
}
// PreprocessMeta records the scaling and encoding outcome, for tool output and
// troubleshooting.
type PreprocessMeta struct {
	OriginalPath string
	OriginalBytes int64
	OriginalWidth int
	OriginalHeight int
	OutputWidth int
	OutputHeight int
	OutputBytes int
	OutputMIMEType string
	JPEGQuality int // 0 means no JPEG re-encoding (original sent as-is)
	PreprocessMode string // passthrough | jpeg
}
// PreprocessOptions holds the image preprocessing parameters.
type PreprocessOptions struct {
	MaxImageBytes int64
	MaxDimension int
	JPEGQuality int
	MaxPayloadBytes int64
	SkipPreprocessBelowBytes int64 // 0 = always compress; >0 lets small images with compliant dimensions pass through unchanged
}
// PreprocessImageFile prepares the image at path for the vision API.
//
// Files larger than opt.MaxImageBytes (when positive) are rejected. Small
// images that satisfy the passthrough conditions are sent unchanged; all
// others are scaled and re-encoded as JPEG. MaxDimension defaults to 2048 and
// MaxPayloadBytes to 512 KiB when not positive.
func PreprocessImageFile(path string, opt PreprocessOptions) (ImagePayload, PreprocessMeta, error) {
	meta := PreprocessMeta{OriginalPath: path}

	info, err := os.Stat(path)
	if err != nil {
		return ImagePayload{}, meta, err
	}
	fileSize := info.Size()
	meta.OriginalBytes = fileSize
	if limit := opt.MaxImageBytes; limit > 0 && fileSize > limit {
		return ImagePayload{}, meta, fmt.Errorf("file size %d exceeds max_image_bytes %d", fileSize, limit)
	}

	width, height, format, err := imageDimensions(path)
	if err != nil {
		return ImagePayload{}, meta, err
	}
	meta.OriginalWidth, meta.OriginalHeight = width, height

	maxDim := opt.MaxDimension
	if maxDim <= 0 {
		maxDim = 2048
	}
	maxPayload := opt.MaxPayloadBytes
	if maxPayload <= 0 {
		maxPayload = 512 * 1024
	}

	// Passthrough is best effort: whenever it declines, fall back to compression.
	direct, directMeta, ok, err := tryPassthrough(path, fileSize, width, height, format, opt, maxDim, maxPayload)
	if ok {
		return direct, directMeta, err
	}
	return compressWithImaging(path, opt, maxDim, maxPayload, meta)
}
// tryPassthrough returns the file's bytes unchanged when the image is small
// enough to skip preprocessing.
//
// The third result is true only when the payload is usable. Passthrough is
// declined when it is disabled (threshold <= 0), when size exceeds the
// threshold or maxPayload, when the longer edge exceeds maxDim, or when the
// decoded format has no known MIME type. A read error is returned together
// with ok == false.
func tryPassthrough(path string, size int64, w, h int, format string, opt PreprocessOptions, maxDim int, maxPayload int64) (ImagePayload, PreprocessMeta, bool, error) {
	meta := PreprocessMeta{
		OriginalPath:   path,
		OriginalBytes:  size,
		OriginalWidth:  w,
		OriginalHeight: h,
	}

	threshold := opt.SkipPreprocessBelowBytes
	if threshold <= 0 || size > threshold || size > maxPayload {
		return ImagePayload{}, meta, false, nil
	}
	longEdge := w
	if h > longEdge {
		longEdge = h
	}
	if longEdge > maxDim {
		return ImagePayload{}, meta, false, nil
	}

	// Decide on the MIME type before touching the file: an unsupported format
	// can never pass through, so reading it would be wasted I/O.
	mime := mimeFromImageFormat(format)
	if mime == "" {
		return ImagePayload{}, meta, false, nil
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		return ImagePayload{}, meta, false, err
	}
	// The limits above were checked against the caller's size; re-check what
	// was actually read in case the file grew in the meantime.
	if int64(len(raw)) > maxPayload {
		return ImagePayload{}, meta, false, nil
	}

	meta.OutputWidth = w
	meta.OutputHeight = h
	meta.OutputBytes = len(raw)
	meta.OutputMIMEType = mime
	meta.PreprocessMode = "passthrough"
	return ImagePayload{Bytes: raw, MIMEType: mime}, meta, true, nil
}
// compressWithImaging scales the image to fit maxDim x maxDim and encodes it
// as JPEG, lowering quality and then dimensions until the result fits in
// maxPayload bytes.
//
// Up to six sizes are tried, each 85% of the previous one. At every size the
// quality is stepped down by 5 until it would fall below 60. The first size
// starts at the requested quality (82 when the request is outside 1..100),
// later sizes start at 75. An error is returned when no combination fits.
func compressWithImaging(path string, opt PreprocessOptions, maxDim int, maxPayload int64, meta PreprocessMeta) (ImagePayload, PreprocessMeta, error) {
	const (
		defaultQuality  = 82
		floorQuality    = 60
		retryQuality    = 75
		qualityStep     = 5
		maxAttempts     = 6
		shrinkFactor    = 0.85
		preferredMinDim = 256
	)

	src, err := imaging.Open(path)
	if err != nil {
		return ImagePayload{}, meta, fmt.Errorf("open image: %w", err)
	}
	bounds := src.Bounds()
	meta.OriginalWidth = bounds.Dx()
	meta.OriginalHeight = bounds.Dy()

	quality := opt.JPEGQuality
	switch {
	case quality <= 0 || quality > 100:
		quality = defaultQuality
	case quality < floorQuality:
		// A request below the floor used to skip every encode on the first
		// attempt; start at the floor instead.
		quality = floorQuality
	}

	// Never let the shrink floor exceed the configured limit, otherwise a
	// retry would produce a larger image than the first attempt.
	minDim := preferredMinDim
	if maxDim < minDim {
		minDim = maxDim
	}

	dim := maxDim
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			dim = int(float64(dim) * shrinkFactor)
			if dim < minDim {
				dim = minDim
			}
		}
		dst := imaging.Fit(src, dim, dim, imaging.Lanczos)
		outBounds := dst.Bounds()
		meta.OutputWidth = outBounds.Dx()
		meta.OutputHeight = outBounds.Dy()

		for q := quality; q >= floorQuality; q -= qualityStep {
			var buf bytes.Buffer
			if err := imaging.Encode(&buf, dst, imaging.JPEG, imaging.JPEGQuality(q)); err != nil {
				return ImagePayload{}, meta, fmt.Errorf("encode jpeg: %w", err)
			}
			if int64(buf.Len()) <= maxPayload {
				meta.JPEGQuality = q
				meta.OutputBytes = buf.Len()
				meta.OutputMIMEType = "image/jpeg"
				meta.PreprocessMode = "jpeg"
				return ImagePayload{Bytes: buf.Bytes(), MIMEType: "image/jpeg"}, meta, nil
			}
		}
		quality = retryQuality
	}
	return ImagePayload{}, meta, fmt.Errorf("could not compress image under max_payload_bytes %d", maxPayload)
}
// imageDimensions reads only the image header at path and returns its width,
// height and decoded format name. Open errors are returned as-is; decode
// errors are wrapped.
func imageDimensions(path string) (w, h int, format string, err error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return 0, 0, "", openErr
	}
	defer file.Close()

	header, kind, decodeErr := image.DecodeConfig(file)
	if decodeErr != nil {
		return 0, 0, "", fmt.Errorf("decode image config: %w", decodeErr)
	}
	return header.Width, header.Height, kind, nil
}
// mimeFromImageFormat maps a decoded image format name to its MIME type,
// ignoring case and surrounding whitespace. Unknown formats yield "".
func mimeFromImageFormat(format string) string {
	kind := strings.ToLower(strings.TrimSpace(format))
	if kind == "jpg" {
		kind = "jpeg"
	}
	switch kind {
	case "jpeg", "png", "gif", "webp", "bmp", "tiff":
		return "image/" + kind
	}
	return ""
}
// DecodeImageConfig opens path and decodes only the image header. It is meant
// for tests that need to confirm a file is decodable.
func DecodeImageConfig(path string) (image.Config, string, error) {
	file, err := os.Open(path)
	if err != nil {
		return image.Config{}, "", err
	}
	defer file.Close()

	cfg, format, err := image.DecodeConfig(file)
	return cfg, format, err
}
+109
View File
@@ -0,0 +1,109 @@
package vision
import (
"image"
"image/color"
"image/png"
"os"
"path/filepath"
"testing"
"github.com/disintegration/imaging"
)
// TestPreprocessImageFile_scalesAndLimitsPayload expects a 3000x2000 image to
// be re-encoded as JPEG within the dimension and payload limits.
func TestPreprocessImageFile_scalesAndLimitsPayload(t *testing.T) {
	const (
		dimLimit     = 1024
		payloadLimit = 600 * 1024
	)

	src := filepath.Join(t.TempDir(), "big.png")
	if err := imaging.Save(imaging.New(3000, 2000, color.White), src); err != nil {
		t.Fatal(err)
	}

	opts := PreprocessOptions{
		MaxImageBytes:            10 * 1024 * 1024,
		MaxDimension:             dimLimit,
		JPEGQuality:              85,
		MaxPayloadBytes:          payloadLimit,
		SkipPreprocessBelowBytes: 0,
	}
	payload, info, err := PreprocessImageFile(src, opts)
	if err != nil {
		t.Fatal(err)
	}

	if len(payload.Bytes) == 0 {
		t.Fatal("empty output")
	}
	if info.PreprocessMode != "jpeg" {
		t.Fatalf("mode: %s", info.PreprocessMode)
	}
	if info.OutputWidth > dimLimit || info.OutputHeight > dimLimit {
		t.Fatalf("expected fit within 1024, got %dx%d", info.OutputWidth, info.OutputHeight)
	}
	if int64(len(payload.Bytes)) > payloadLimit {
		t.Fatalf("payload %d exceeds max", len(payload.Bytes))
	}
}
// TestPreprocessImageFile_passthroughSmallPNG expects a small PNG below the
// skip threshold to be passed through unchanged.
func TestPreprocessImageFile_passthroughSmallPNG(t *testing.T) {
	src := filepath.Join(t.TempDir(), "small.png")
	if err := imaging.Save(imaging.New(400, 300, color.White), src); err != nil {
		t.Fatal(err)
	}

	opts := PreprocessOptions{
		MaxImageBytes:            5 * 1024 * 1024,
		MaxDimension:             2048,
		MaxPayloadBytes:          512 * 1024,
		SkipPreprocessBelowBytes: 2 * 1024 * 1024,
	}
	payload, info, err := PreprocessImageFile(src, opts)
	if err != nil {
		t.Fatal(err)
	}

	if info.PreprocessMode != "passthrough" {
		t.Fatalf("expected passthrough, got %s", info.PreprocessMode)
	}
	if payload.MIMEType != "image/png" {
		t.Fatalf("mime: %s", payload.MIMEType)
	}
	if info.OutputWidth != 400 || info.OutputHeight != 300 {
		t.Fatalf("dims: %dx%d", info.OutputWidth, info.OutputHeight)
	}
}
// TestPreprocessImageFile_passthroughDisabled expects even a tiny image to be
// JPEG-compressed when the skip threshold is zero.
func TestPreprocessImageFile_passthroughDisabled(t *testing.T) {
	src := filepath.Join(t.TempDir(), "small.png")
	if err := imaging.Save(imaging.New(100, 100, color.White), src); err != nil {
		t.Fatal(err)
	}

	opts := PreprocessOptions{
		MaxDimension:             2048,
		MaxPayloadBytes:          512 * 1024,
		SkipPreprocessBelowBytes: 0,
	}
	_, info, err := PreprocessImageFile(src, opts)
	if err != nil {
		t.Fatal(err)
	}
	if mode := info.PreprocessMode; mode != "jpeg" {
		t.Fatalf("expected jpeg compress, got %s", mode)
	}
}
func TestPreprocessImageFile_rejectsOversizeFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tiny.png")
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
if err := png.Encode(f, image.NewRGBA(image.Rect(0, 0, 2, 2))); err != nil {
t.Fatal(err)
}
f.Close()
_, _, err = PreprocessImageFile(path, PreprocessOptions{MaxImageBytes: 1})
if err == nil {
t.Fatal("expected error when file exceeds max_image_bytes")
}
}
+130
View File
@@ -0,0 +1,130 @@
package vision
import (
"context"
"fmt"
"os"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterAnalyzeImageTool registers the analyze_image MCP tool on mcpServer
// when cfg.Vision.Ready() reports true.
//
// Nothing is registered when mcpServer or cfg is nil, when vision is not
// ready, or when the working directory cannot be determined. The path
// whitelist, the preprocessing limits and the vision client are all built once
// here from cfg. A nil logger is tolerated.
func RegisterAnalyzeImageTool(mcpServer *mcp.Server, cfg *config.Config, logger *zap.Logger) {
	if mcpServer == nil || cfg == nil {
		return
	}
	canLog := logger != nil

	if !cfg.Vision.Ready() {
		if cfg.Vision.Enabled && canLog {
			logger.Warn("vision.enabled 但 vision.model 为空,跳过注册 analyze_image")
		}
		return
	}

	workDir, wdErr := os.Getwd()
	if wdErr != nil {
		if canLog {
			logger.Warn("vision: getwd failed, skip analyze_image", zap.Error(wdErr))
		}
		return
	}

	roots := PathOptions{
		CWD:              workDir,
		ResultStorageDir: cfg.Agent.ResultStorageDir,
		ExtraRoots:       cfg.Vision.AllowedRoots,
	}
	limits := PreprocessOptions{
		MaxImageBytes:            cfg.Vision.MaxImageBytesEffective(),
		MaxDimension:             cfg.Vision.MaxDimensionEffective(),
		JPEGQuality:              cfg.Vision.JPEGQualityEffective(),
		MaxPayloadBytes:          cfg.Vision.MaxPayloadBytesEffective(),
		SkipPreprocessBelowBytes: cfg.Vision.SkipPreprocessBelowBytesEffective(),
	}
	visionClient := NewClient(cfg.Vision, cfg.OpenAI)

	inputSchema := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"path": map[string]interface{}{
				"type":        "string",
				"description": "图片绝对路径或相对于进程工作目录的路径",
			},
			"question": map[string]interface{}{
				"type":        "string",
				"description": "可选:希望模型重点回答的问题。验证码图建议:只输出验证码字符,不要空格和解释",
			},
		},
		"required": []string{"path"},
	}
	tool := mcp.Tool{
		Name: builtin.ToolAnalyzeImage,
		Description: "分析服务器上的本地图片并返回文字描述(验证码、UI 元素、报错、架构图要点等)。" +
			"输入为文件路径(如用户上传的 chat_uploads 路径或工具截图路径)。" +
			"输出仅为文本,不含图片数据。不要对二进制图片使用 read_file 指望理解内容。",
		ShortDescription: "分析本地图片并返回文字描述(验证码/UI/报错等)",
		InputSchema:      inputSchema,
	}

	// Failures are reported to the caller as error-flagged text results, never
	// as Go errors.
	analyze := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		// Missing or non-string arguments become "".
		rawPath, _ := args["path"].(string)
		question, _ := args["question"].(string)

		imagePath, resolveErr := ResolveImagePath(rawPath, roots)
		if resolveErr != nil {
			return textResult(fmt.Sprintf("路径校验失败: %v", resolveErr), true), nil
		}
		payload, meta, prepErr := PreprocessImageFile(imagePath, limits)
		if prepErr != nil {
			return textResult(fmt.Sprintf("图片预处理失败: %v", prepErr), true), nil
		}
		summary, callErr := visionClient.Analyze(ctx, payload, question)
		if callErr != nil {
			return textResult(fmt.Sprintf("视觉模型调用失败: %v", callErr), true), nil
		}
		return textResult(formatAnalysisResult(imagePath, meta, summary), false), nil
	}

	mcpServer.RegisterTool(tool, analyze)
	if canLog {
		logger.Info("vision: analyze_image 工具已注册", zap.String("model", cfg.Vision.Model))
	}
}
// textResult wraps text in a tool result with a single "text" content item and
// the given error flag.
func textResult(text string, isError bool) *mcp.ToolResult {
	result := new(mcp.ToolResult)
	result.Content = []mcp.Content{{Type: "text", Text: text}}
	result.IsError = isError
	return result
}
// formatAnalysisResult renders the tool output: a heading, the image path, one
// line describing the preprocessing (sizes rounded up to whole KB) and the
// trimmed model summary.
func formatAnalysisResult(path string, meta PreprocessMeta, summary string) string {
	outKB := (meta.OutputBytes + 1023) / 1024
	origKB := (meta.OriginalBytes + 1023) / 1024

	var preprocess string
	if meta.PreprocessMode == "passthrough" {
		preprocess = fmt.Sprintf("- **preprocess**: passthrough %dx%d, %s, %dKB (original %dKB)\n\n",
			meta.OutputWidth, meta.OutputHeight, meta.OutputMIMEType, outKB, origKB)
	} else {
		preprocess = fmt.Sprintf("- **preprocess**: %dx%d → %dx%d, jpeg q=%d, %dKB (original %dKB)\n\n",
			meta.OriginalWidth, meta.OriginalHeight,
			meta.OutputWidth, meta.OutputHeight,
			meta.JPEGQuality, outKB, origKB)
	}

	return "## Image analysis\n" +
		"- **path**: " + path + "\n" +
		preprocess +
		"### Summary\n" +
		strings.TrimSpace(summary) + "\n"
}