commit add33e1cf72209938793f6a793dbffa35825e0bc
Author: 公明 <83812544+Ed1s0nZ@users.noreply.github.com>
Date: Sat Nov 8 18:56:23 2025 +0800
Add files via upload
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..479312ce
--- /dev/null
+++ b/README.md
@@ -0,0 +1,319 @@
+# CyberStrikeAI
+
+基于Golang和Gin框架的AI驱动自主渗透测试平台,使用MCP协议集成安全工具。
+
+## 功能特性
+
+- 🤖 **AI代理连接** - 支持Claude、GPT等兼容MCP的AI代理通过FastMCP协议连接
+- 🧠 **智能分析** - 决策引擎分析目标并选择最佳测试策略
+- ⚡ **自主执行** - AI代理执行全面的安全评估
+- 🔄 **实时适应** - 系统根据结果和发现的漏洞进行调整
+- 📊 **高级报告** - 可视化方式输出漏洞卡片和风险分析
+- 💬 **对话式交互** - 前端以对话形式调用后端agent-loop
+- 📈 **实时监控** - 监控安全工具的执行状态、结果、调用次数等
+
+## 项目结构
+
+```
+CyberStrikeAI/
+├── cmd/
+│ └── server/
+│ └── main.go # 程序入口
+├── internal/
+│ ├── agent/ # AI代理模块
+│ ├── app/ # 应用初始化
+│ ├── config/ # 配置管理
+│ ├── handler/ # HTTP处理器
+│ ├── logger/ # 日志系统
+│ ├── mcp/ # MCP协议实现
+│ └── security/ # 安全工具执行器
+├── web/
+│ ├── static/ # 静态资源
+│ │ ├── css/
+│ │ └── js/
+│ └── templates/ # HTML模板
+├── config.yaml # 配置文件
+├── go.mod # Go模块文件
+└── README.md # 说明文档
+```
+
+## 快速开始
+
+### 前置要求
+
+- Go 1.21 或更高版本
+- OpenAI API Key(或其他兼容OpenAI协议的API)
+- 安全工具(可选):nmap, sqlmap, nikto, dirb
+
+### 安装步骤
+
+1. **克隆项目**
+```bash
+cd /Users/temp/Desktop/wenjian/tools/CyberStrikeAI
+```
+
+2. **安装依赖**
+```bash
+go mod download
+```
+
+3. **配置**
+编辑 `config.yaml` 文件,设置您的OpenAI API Key:
+```yaml
+openai:
+ api_key: "sk-your-api-key-here"
+ base_url: "https://api.openai.com/v1"
+ model: "gpt-4"
+```
+
+4. **启动服务器**
+
+#### 方式一:使用启动脚本
+```bash
+./run.sh
+```
+
+#### 方式二:直接运行
+```bash
+go run cmd/server/main.go
+```
+
+#### 方式三:编译后运行
+```bash
+go build -o cyberstrike-ai cmd/server/main.go
+./cyberstrike-ai
+```
+
+5. **访问应用**
+打开浏览器访问:http://localhost:8080
+
+## 配置说明
+
+### 服务器配置
+```yaml
+server:
+ host: "0.0.0.0"
+ port: 8080
+```
+
+### MCP配置
+```yaml
+mcp:
+ enabled: true
+ host: "0.0.0.0"
+ port: 8081
+```
+
+### 安全工具配置
+```yaml
+security:
+ tools:
+ - name: "nmap"
+ command: "nmap"
+ args: ["-sV", "-sC"]
+ description: "网络扫描工具"
+ enabled: true
+```
+
+## 使用示例
+
+### 对话式渗透测试
+
+在"对话测试"标签页中,您可以:
+
+1. **网络扫描**
+ ```
+ 扫描 192.168.1.1 的开放端口
+ ```
+
+2. **SQL注入检测**
+ ```
+ 检测 https://example.com 的SQL注入漏洞
+ ```
+
+3. **Web漏洞扫描**
+ ```
+ 扫描 https://example.com 的Web服务器漏洞
+ ```
+
+4. **目录扫描**
+ ```
+ 扫描 https://example.com 的隐藏目录
+ ```
+
+### 监控工具执行
+
+在"工具监控"标签页中,您可以:
+
+- 查看所有工具的执行统计
+- 查看详细的执行记录
+- 查看发现的漏洞列表
+- 实时监控工具状态
+
+## API接口
+
+### Agent Loop API
+
+**POST** `/api/agent-loop`
+
+请求体:
+```json
+{
+ "message": "扫描 192.168.1.1"
+}
+```
+
+使用示例:
+```bash
+curl -X POST http://localhost:8080/api/agent-loop \
+ -H "Content-Type: application/json" \
+ -d '{"message": "扫描 192.168.1.1"}'
+```
+
+### 监控API
+
+- **GET** `/api/monitor` - 获取所有监控信息
+- **GET** `/api/monitor/execution/:id` - 获取特定执行记录
+- **GET** `/api/monitor/stats` - 获取统计信息
+- **GET** `/api/monitor/vulnerabilities` - 获取漏洞列表
+
+使用示例:
+```bash
+# 获取所有监控信息
+curl http://localhost:8080/api/monitor
+
+# 获取统计信息
+curl http://localhost:8080/api/monitor/stats
+
+# 获取漏洞列表
+curl http://localhost:8080/api/monitor/vulnerabilities
+```
+
+### MCP接口
+
+**POST** `/api/mcp` - MCP协议端点
+
+## MCP协议
+
+本项目实现了MCP(Model Context Protocol)协议,支持:
+
+- `initialize` - 初始化连接
+- `tools/list` - 列出可用工具
+- `tools/call` - 调用工具
+
+工具调用是异步执行的,系统会跟踪每个工具的执行状态和结果。
+
+### MCP协议使用示例
+
+#### 初始化连接
+
+```bash
+curl -X POST http://localhost:8080/api/mcp \
+ -H "Content-Type: application/json" \
+ -d '{
+ "jsonrpc": "2.0",
+ "id": "1",
+ "method": "initialize",
+ "params": {
+ "protocolVersion": "2024-11-05",
+ "capabilities": {},
+ "clientInfo": {
+ "name": "test-client",
+ "version": "1.0.0"
+ }
+ }
+ }'
+```
+
+#### 列出工具
+
+```bash
+curl -X POST http://localhost:8080/api/mcp \
+ -H "Content-Type: application/json" \
+ -d '{
+ "jsonrpc": "2.0",
+ "id": "2",
+ "method": "tools/list"
+ }'
+```
+
+#### 调用工具
+
+```bash
+curl -X POST http://localhost:8080/api/mcp \
+ -H "Content-Type: application/json" \
+ -d '{
+ "jsonrpc": "2.0",
+ "id": "3",
+ "method": "tools/call",
+ "params": {
+ "name": "nmap",
+ "arguments": {
+ "target": "192.168.1.1",
+ "ports": "1-1000"
+ }
+ }
+ }'
+```
+
+## 安全工具支持
+
+当前支持的安全工具:
+- **nmap** - 网络扫描
+- **sqlmap** - SQL注入检测
+- **nikto** - Web服务器扫描
+- **dirb** - 目录扫描
+
+可以通过修改 `config.yaml` 添加更多工具。
+
+## 故障排除
+
+### 问题:无法连接到OpenAI API
+
+- 检查API Key是否正确
+- 检查网络连接
+- 检查base_url配置
+
+### 问题:工具执行失败
+
+- 确保已安装相应的安全工具(nmap, sqlmap等)
+- 检查工具是否在PATH中
+- 某些工具可能需要root权限
+
+### 问题:前端无法加载
+
+- 检查服务器是否正常运行
+- 检查端口8080是否被占用
+- 查看浏览器控制台错误信息
+
+## 安全注意事项
+
+⚠️ **重要提示**:
+
+- 仅对您拥有或已获得授权的系统进行测试
+- 遵守相关法律法规
+- 建议在隔离的测试环境中使用
+- 不要在生产环境中使用
+- 某些安全工具可能需要root权限
+
+## 开发
+
+### 添加新工具
+
+1. 在 `config.yaml` 中添加工具配置
+2. 在 `internal/security/executor.go` 的 `buildCommandArgs` 方法中添加参数构建逻辑
+3. 在 `internal/agent/agent.go` 的 `getAvailableTools` 方法中添加工具定义
+
+### 构建
+
+```bash
+go build -o cyberstrike-ai cmd/server/main.go
+```
+
+## 许可证
+
+本项目仅供学习和研究使用。
+
+## 贡献
+
+欢迎提交Issue和Pull Request!
diff --git a/cmd/server/main.go b/cmd/server/main.go
new file mode 100644
index 00000000..cb4292bd
--- /dev/null
+++ b/cmd/server/main.go
@@ -0,0 +1,36 @@
+package main
+
+import (
+ "cyberstrike-ai/internal/app"
+ "cyberstrike-ai/internal/config"
+ "cyberstrike-ai/internal/logger"
+ "flag"
+ "fmt"
+)
+
+func main() {
+ var configPath = flag.String("config", "config.yaml", "配置文件路径")
+ flag.Parse()
+
+ // 加载配置
+ cfg, err := config.Load(*configPath)
+ if err != nil {
+ fmt.Printf("加载配置失败: %v\n", err)
+ return
+ }
+
+ // 初始化日志
+ log := logger.New(cfg.Log.Level, cfg.Log.Output)
+
+ // 创建应用
+ application, err := app.New(cfg, log)
+ if err != nil {
+ log.Fatal("应用初始化失败", "error", err)
+ }
+
+ // 启动服务器
+ if err := application.Run(); err != nil {
+ log.Fatal("服务器启动失败", "error", err)
+ }
+}
+
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 00000000..fa2672db
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,174 @@
+server:
+ host: "0.0.0.0"
+ port: 8080
+
+log:
+ level: "info"
+ output: "stdout"
+
+mcp:
+ enabled: true
+ host: "0.0.0.0"
+ port: 8081
+
+openai:
+ api_key: "sk-xxx" # 请设置您的OpenAI API Key
+ base_url: "https://api.deepseek.com/v1"
+ model: "deepseek-chat"
+
+database:
+ path: "data/conversations.db"
+
+security:
+ tools:
+ # 示例1: 使用参数定义的工具(推荐方式)
+ - name: "nmap"
+ command: "nmap"
+ args: ["-sT", "-sV", "-sC"] # 固定参数
+ description: "网络扫描工具,用于发现网络主机和服务"
+ enabled: true
+ parameters:
+ - name: "target"
+ type: "string"
+ description: "目标IP地址或域名"
+ required: true
+ position: 0 # 位置参数,放在最后
+ format: "positional"
+ - name: "ports"
+ type: "string"
+ description: "端口范围,例如: 1-1000, 80,443,8080"
+ required: false
+ flag: "-p"
+ format: "flag"
+
+ # 示例2: 标志参数工具
+ - name: "sqlmap"
+ command: "sqlmap"
+ description: "SQL注入检测和利用工具"
+ enabled: true
+ parameters:
+ - name: "url"
+ type: "string"
+ description: "目标URL,例如: http://example.com/page?id=1"
+ required: true
+ flag: "-u"
+ format: "flag"
+ - name: "batch"
+ type: "bool"
+ description: "非交互模式"
+ required: false
+ default: true
+ flag: "--batch"
+ format: "flag"
+ - name: "level"
+ type: "int"
+ description: "测试级别 (1-5)"
+ required: false
+ default: 3
+ flag: "--level"
+ format: "combined" # --level=3
+
+ # 示例3: 位置参数工具
+ - name: "nikto"
+ command: "nikto"
+ description: "Web服务器扫描工具"
+ enabled: true
+ parameters:
+ - name: "target"
+ type: "string"
+ description: "目标URL或IP地址"
+ required: true
+ flag: "-h"
+ format: "flag"
+
+ # 示例4: 简单位置参数
+ - name: "dirb"
+ command: "dirb"
+ description: "Web目录扫描工具"
+ enabled: true
+ parameters:
+ - name: "url"
+ type: "string"
+ description: "目标URL"
+ required: true
+ position: 0
+ format: "positional"
+ - name: "wordlist"
+ type: "string"
+ description: "字典文件路径"
+ required: false
+ flag: "-w"
+ format: "flag"
+
+ # 示例5: 执行系统命令
+ - name: "exec"
+ command: "sh"
+ args: ["-c"]
+ description: "执行系统命令(谨慎使用)"
+ enabled: true
+ parameters:
+ - name: "command"
+ type: "string"
+ description: "要执行的系统命令"
+ required: true
+ position: 0
+ format: "positional"
+ - name: "shell"
+ type: "string"
+ description: "使用的shell(可选,默认为sh)"
+ required: false
+ default: "sh"
+ - name: "workdir"
+ type: "string"
+ description: "工作目录"
+ required: false
+
+ # 示例6: 自定义工具 - 使用模板格式
+ - name: "custom_scanner"
+ command: "my-scanner"
+ description: "自定义扫描工具示例"
+ enabled: false # 默认禁用,需要时启用
+ parameters:
+ - name: "target"
+ type: "string"
+ description: "扫描目标"
+ required: true
+ flag: "--target"
+ format: "flag"
+ - name: "mode"
+ type: "string"
+ description: "扫描模式"
+ required: false
+ default: "normal"
+ options: ["normal", "aggressive", "stealth"] # 枚举值
+ flag: "--mode"
+ format: "combined" # --mode=normal
+ - name: "threads"
+ type: "int"
+ description: "线程数"
+ required: false
+ default: 10
+ flag: "-t"
+ format: "flag"
+ - name: "output"
+ type: "string"
+ description: "输出文件路径"
+ required: false
+ flag: "-o"
+ format: "template"
+ template: "-o {value}" # 自定义模板
+ - name: "verbose"
+ type: "bool"
+ description: "详细输出"
+ required: false
+ default: false
+ flag: "-v"
+ format: "flag" # 布尔值:如果为true,只添加-v,不添加值
+
+ # 示例7: 向后兼容 - 不定义parameters,使用旧的硬编码逻辑
+ # - name: "legacy_tool"
+ # command: "legacy"
+ # args: ["--option"]
+ # description: "旧工具(使用硬编码逻辑)"
+ # enabled: false
+
diff --git a/data/conversations.db b/data/conversations.db
new file mode 100644
index 00000000..4ebf78cf
Binary files /dev/null and b/data/conversations.db differ
diff --git a/data/conversations.db-shm b/data/conversations.db-shm
new file mode 100644
index 00000000..8dfd5497
Binary files /dev/null and b/data/conversations.db-shm differ
diff --git a/data/conversations.db-wal b/data/conversations.db-wal
new file mode 100644
index 00000000..9b8a63c9
Binary files /dev/null and b/data/conversations.db-wal differ
diff --git a/go.mod b/go.mod
new file mode 100644
index 00000000..e5233bfe
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,38 @@
+module cyberstrike-ai
+
+go 1.21
+
+require (
+ github.com/gin-gonic/gin v1.9.1
+ github.com/google/uuid v1.5.0
+ github.com/mattn/go-sqlite3 v1.14.18
+ go.uber.org/zap v1.26.0
+ gopkg.in/yaml.v3 v3.0.1
+)
+
+require (
+ github.com/bytedance/sonic v1.9.1 // indirect
+ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
+ github.com/gabriel-vasile/mimetype v1.4.2 // indirect
+ github.com/gin-contrib/sse v0.1.0 // indirect
+ github.com/go-playground/locales v0.14.1 // indirect
+ github.com/go-playground/universal-translator v0.18.1 // indirect
+ github.com/go-playground/validator/v10 v10.14.0 // indirect
+ github.com/goccy/go-json v0.10.2 // indirect
+ github.com/json-iterator/go v1.1.12 // indirect
+ github.com/klauspost/cpuid/v2 v2.2.4 // indirect
+ github.com/leodido/go-urn v1.2.4 // indirect
+ github.com/mattn/go-isatty v0.0.19 // indirect
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
+ github.com/modern-go/reflect2 v1.0.2 // indirect
+ github.com/pelletier/go-toml/v2 v2.0.8 // indirect
+ github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
+ github.com/ugorji/go/codec v1.2.11 // indirect
+ go.uber.org/multierr v1.11.0 // indirect
+ golang.org/x/arch v0.3.0 // indirect
+ golang.org/x/crypto v0.14.0 // indirect
+ golang.org/x/net v0.17.0 // indirect
+ golang.org/x/sys v0.13.0 // indirect
+ golang.org/x/text v0.13.0 // indirect
+ google.golang.org/protobuf v1.30.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 00000000..2d559b00
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,96 @@
+github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
+github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
+github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
+github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
+github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
+github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
+github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
+github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
+github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
+github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
+github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
+github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
+github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
+github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
+github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
+github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
+github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
+github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
+github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
+github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
+github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
+github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
+github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
+github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
+github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
+github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
+github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
+github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
+github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
+github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
+github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
+github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
+github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
+github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
+github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
+github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
+github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
+go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
+go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
+go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
+go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
+go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
+golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
+golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
+golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
+golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
+golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
+golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
+golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
+golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
+golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
+golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
+google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/internal/agent/agent.go b/internal/agent/agent.go
new file mode 100644
index 00000000..7110138c
--- /dev/null
+++ b/internal/agent/agent.go
@@ -0,0 +1,576 @@
+package agent
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "cyberstrike-ai/internal/config"
+ "cyberstrike-ai/internal/mcp"
+ "go.uber.org/zap"
+)
+
+// Agent AI代理
+type Agent struct {
+ openAIClient *http.Client
+ config *config.OpenAIConfig
+ mcpServer *mcp.Server
+ logger *zap.Logger
+}
+
+// NewAgent 创建新的Agent
+func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logger) *Agent {
+ return &Agent{
+ openAIClient: &http.Client{Timeout: 5 * time.Minute},
+ config: cfg,
+ mcpServer: mcpServer,
+ logger: logger,
+ }
+}
+
+// ChatMessage 聊天消息
+type ChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+}
+
+// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
+func (cm ChatMessage) MarshalJSON() ([]byte, error) {
+ // 构建序列化结构
+ aux := map[string]interface{}{
+ "role": cm.Role,
+ }
+
+ // 添加content(如果存在)
+ if cm.Content != "" {
+ aux["content"] = cm.Content
+ }
+
+ // 添加tool_call_id(如果存在)
+ if cm.ToolCallID != "" {
+ aux["tool_call_id"] = cm.ToolCallID
+ }
+
+ // 转换tool_calls,将arguments转换为JSON字符串
+ if len(cm.ToolCalls) > 0 {
+ toolCallsJSON := make([]map[string]interface{}, len(cm.ToolCalls))
+ for i, tc := range cm.ToolCalls {
+ // 将arguments转换为JSON字符串
+ argsJSON := ""
+ if tc.Function.Arguments != nil {
+ argsBytes, err := json.Marshal(tc.Function.Arguments)
+ if err != nil {
+ return nil, err
+ }
+ argsJSON = string(argsBytes)
+ }
+
+ toolCallsJSON[i] = map[string]interface{}{
+ "id": tc.ID,
+ "type": tc.Type,
+ "function": map[string]interface{}{
+ "name": tc.Function.Name,
+ "arguments": argsJSON,
+ },
+ }
+ }
+ aux["tool_calls"] = toolCallsJSON
+ }
+
+ return json.Marshal(aux)
+}
+
+// OpenAIRequest OpenAI API请求
+type OpenAIRequest struct {
+ Model string `json:"model"`
+ Messages []ChatMessage `json:"messages"`
+ Tools []Tool `json:"tools,omitempty"`
+}
+
+// OpenAIResponse OpenAI API响应
+type OpenAIResponse struct {
+ ID string `json:"id"`
+ Choices []Choice `json:"choices"`
+ Error *Error `json:"error,omitempty"`
+}
+
+// Choice 选择
+type Choice struct {
+ Message MessageWithTools `json:"message"`
+ FinishReason string `json:"finish_reason"`
+}
+
+// MessageWithTools 带工具调用的消息
+type MessageWithTools struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+}
+
+// Tool OpenAI工具定义
+type Tool struct {
+ Type string `json:"type"`
+ Function FunctionDefinition `json:"function"`
+}
+
+// FunctionDefinition 函数定义
+type FunctionDefinition struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters map[string]interface{} `json:"parameters"`
+}
+
+// Error OpenAI错误
+type Error struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+}
+
+// ToolCall 工具调用
+type ToolCall struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Function FunctionCall `json:"function"`
+}
+
+// FunctionCall 函数调用
+type FunctionCall struct {
+ Name string `json:"name"`
+ Arguments map[string]interface{} `json:"arguments"`
+}
+
+// UnmarshalJSON 自定义JSON解析,处理arguments可能是字符串或对象的情况
+func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
+ type Alias FunctionCall
+ aux := &struct {
+ Name string `json:"name"`
+ Arguments interface{} `json:"arguments"`
+ *Alias
+ }{
+ Alias: (*Alias)(fc),
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ fc.Name = aux.Name
+
+ // 处理arguments可能是字符串或对象的情况
+ switch v := aux.Arguments.(type) {
+ case map[string]interface{}:
+ fc.Arguments = v
+ case string:
+ // 如果是字符串,尝试解析为JSON
+ if err := json.Unmarshal([]byte(v), &fc.Arguments); err != nil {
+ // 如果解析失败,创建一个包含原始字符串的map
+ fc.Arguments = map[string]interface{}{
+ "raw": v,
+ }
+ }
+ case nil:
+ fc.Arguments = make(map[string]interface{})
+ default:
+ // 其他类型,尝试转换为map
+ fc.Arguments = map[string]interface{}{
+ "value": v,
+ }
+ }
+
+ return nil
+}
+
+// AgentLoopResult Agent Loop执行结果
+type AgentLoopResult struct {
+ Response string
+ MCPExecutionIDs []string
+}
+
+// AgentLoop 执行Agent循环
+func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
+ messages := []ChatMessage{
+ {
+ Role: "system",
+ Content: "你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。当需要执行工具时,使用提供的工具函数。",
+ },
+ }
+
+ // 添加历史消息(数据库只保存user和assistant消息)
+ a.logger.Info("处理历史消息",
+ zap.Int("count", len(historyMessages)),
+ )
+ addedCount := 0
+ for i, msg := range historyMessages {
+ // 只添加有内容的消息
+ if msg.Content != "" {
+ messages = append(messages, ChatMessage{
+ Role: msg.Role,
+ Content: msg.Content,
+ })
+ addedCount++
+ contentPreview := msg.Content
+ if len(contentPreview) > 50 {
+ contentPreview = contentPreview[:50] + "..."
+ }
+ a.logger.Info("添加历史消息到上下文",
+ zap.Int("index", i),
+ zap.String("role", msg.Role),
+ zap.String("content", contentPreview),
+ )
+ }
+ }
+
+ a.logger.Info("构建消息数组",
+ zap.Int("historyMessages", len(historyMessages)),
+ zap.Int("addedMessages", addedCount),
+ zap.Int("totalMessages", len(messages)),
+ )
+
+ // 添加当前用户消息
+ messages = append(messages, ChatMessage{
+ Role: "user",
+ Content: userInput,
+ })
+
+ result := &AgentLoopResult{
+ MCPExecutionIDs: make([]string, 0),
+ }
+
+ maxIterations := 10
+ for i := 0; i < maxIterations; i++ {
+ // 获取可用工具
+ tools := a.getAvailableTools()
+
+ // 记录每次调用OpenAI
+ if i == 0 {
+ a.logger.Info("调用OpenAI",
+ zap.Int("iteration", i+1),
+ zap.Int("messagesCount", len(messages)),
+ )
+ // 记录前几条消息的内容(用于调试)
+ for j, msg := range messages {
+ if j >= 5 { // 只记录前5条
+ break
+ }
+ contentPreview := msg.Content
+ if len(contentPreview) > 100 {
+ contentPreview = contentPreview[:100] + "..."
+ }
+ a.logger.Debug("消息内容",
+ zap.Int("index", j),
+ zap.String("role", msg.Role),
+ zap.String("content", contentPreview),
+ )
+ }
+ } else {
+ a.logger.Info("调用OpenAI",
+ zap.Int("iteration", i+1),
+ zap.Int("messagesCount", len(messages)),
+ )
+ }
+
+ // 调用OpenAI
+ response, err := a.callOpenAI(ctx, messages, tools)
+ if err != nil {
+ result.Response = ""
+ return result, fmt.Errorf("调用OpenAI失败: %w", err)
+ }
+
+ if response.Error != nil {
+ result.Response = ""
+ return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
+ }
+
+ if len(response.Choices) == 0 {
+ result.Response = ""
+ return result, fmt.Errorf("没有收到响应")
+ }
+
+ choice := response.Choices[0]
+
+ // 检查是否有工具调用
+ if len(choice.Message.ToolCalls) > 0 {
+ // 添加assistant消息(包含工具调用)
+ messages = append(messages, ChatMessage{
+ Role: "assistant",
+ Content: choice.Message.Content,
+ ToolCalls: choice.Message.ToolCalls,
+ })
+
+ // 执行所有工具调用
+ for _, toolCall := range choice.Message.ToolCalls {
+ // 执行工具
+ execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
+ if err != nil {
+ messages = append(messages, ChatMessage{
+ Role: "tool",
+ ToolCallID: toolCall.ID,
+ Content: fmt.Sprintf("工具执行失败: %v", err),
+ })
+ } else {
+ messages = append(messages, ChatMessage{
+ Role: "tool",
+ ToolCallID: toolCall.ID,
+ Content: execResult.Result,
+ })
+ // 收集执行ID
+ if execResult.ExecutionID != "" {
+ result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID)
+ }
+ }
+ }
+ continue
+ }
+
+ // 添加assistant响应
+ messages = append(messages, ChatMessage{
+ Role: "assistant",
+ Content: choice.Message.Content,
+ })
+
+ // 如果完成,返回结果
+ if choice.FinishReason == "stop" {
+ result.Response = choice.Message.Content
+ return result, nil
+ }
+ }
+
+ result.Response = "达到最大迭代次数"
+ return result, nil
+}
+
+// getAvailableTools 获取可用工具
+func (a *Agent) getAvailableTools() []Tool {
+ // 从MCP服务器获取工具列表
+ executions := a.mcpServer.GetAllExecutions()
+ toolNames := make(map[string]bool)
+ for _, exec := range executions {
+ toolNames[exec.ToolName] = true
+ }
+
+ tools := []Tool{
+ {
+ Type: "function",
+ Function: FunctionDefinition{
+ Name: "nmap",
+ Description: "使用nmap进行网络扫描,发现开放端口和服务。支持IP地址、域名或URL(会自动提取域名)。使用TCP连接扫描,不需要root权限。",
+ Parameters: map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "target": map[string]interface{}{
+ "type": "string",
+ "description": "目标IP地址、域名或URL(如 https://example.com)。如果是URL,会自动提取域名部分。",
+ },
+ "ports": map[string]interface{}{
+ "type": "string",
+ "description": "要扫描的端口范围,例如: 1-1000 或 80,443,8080。如果不指定,将扫描常用端口。",
+ },
+ },
+ "required": []string{"target"},
+ },
+ },
+ },
+ {
+ Type: "function",
+ Function: FunctionDefinition{
+ Name: "sqlmap",
+ Description: "使用sqlmap检测SQL注入漏洞",
+ Parameters: map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "url": map[string]interface{}{
+ "type": "string",
+ "description": "目标URL",
+ },
+ },
+ "required": []string{"url"},
+ },
+ },
+ },
+ {
+ Type: "function",
+ Function: FunctionDefinition{
+ Name: "nikto",
+ Description: "使用nikto扫描Web服务器漏洞",
+ Parameters: map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "target": map[string]interface{}{
+ "type": "string",
+ "description": "目标URL",
+ },
+ },
+ "required": []string{"target"},
+ },
+ },
+ },
+ {
+ Type: "function",
+ Function: FunctionDefinition{
+ Name: "dirb",
+ Description: "使用dirb进行目录扫描",
+ Parameters: map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "url": map[string]interface{}{
+ "type": "string",
+ "description": "目标URL",
+ },
+ },
+ "required": []string{"url"},
+ },
+ },
+ },
+ {
+ Type: "function",
+ Function: FunctionDefinition{
+ Name: "exec",
+ Description: "执行系统命令(谨慎使用,仅用于必要的系统操作)",
+ Parameters: map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "command": map[string]interface{}{
+ "type": "string",
+ "description": "要执行的系统命令",
+ },
+ "shell": map[string]interface{}{
+ "type": "string",
+ "description": "使用的shell(可选,默认为sh)",
+ },
+ "workdir": map[string]interface{}{
+ "type": "string",
+ "description": "工作目录(可选)",
+ },
+ },
+ "required": []string{"command"},
+ },
+ },
+ },
+ }
+
+ return tools
+}
+
+// callOpenAI 调用OpenAI API
+func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
+ reqBody := OpenAIRequest{
+ Model: a.config.Model,
+ Messages: messages,
+ }
+
+ if len(tools) > 0 {
+ reqBody.Tools = tools
+ }
+
+ jsonData, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+a.config.APIKey)
+
+ resp, err := a.openAIClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // 记录响应内容(用于调试)
+ if resp.StatusCode != http.StatusOK {
+ a.logger.Warn("OpenAI API返回非200状态码",
+ zap.Int("status", resp.StatusCode),
+ zap.String("body", string(body)),
+ )
+ }
+
+ var response OpenAIResponse
+ if err := json.Unmarshal(body, &response); err != nil {
+ a.logger.Error("解析OpenAI响应失败",
+ zap.Error(err),
+ zap.String("body", string(body)),
+ )
+ return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body))
+ }
+
+ return &response, nil
+}
+
+// parseToolCall 解析工具调用
+func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) {
+ // 简单解析,实际应该更复杂
+ // 格式: [TOOL_CALL]tool_name:arg1=value1,arg2=value2
+ if !strings.HasPrefix(content, "[TOOL_CALL]") {
+ return nil, fmt.Errorf("不是有效的工具调用格式")
+ }
+
+ parts := strings.Split(content[len("[TOOL_CALL]"):], ":")
+ if len(parts) < 2 {
+ return nil, fmt.Errorf("工具调用格式错误")
+ }
+
+ toolName := strings.TrimSpace(parts[0])
+ argsStr := strings.TrimSpace(parts[1])
+
+ args := make(map[string]interface{})
+ argPairs := strings.Split(argsStr, ",")
+ for _, pair := range argPairs {
+ kv := strings.Split(pair, "=")
+ if len(kv) == 2 {
+ args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
+ }
+ }
+
+ args["_tool_name"] = toolName
+ return args, nil
+}
+
+// ToolExecutionResult 工具执行结果
+type ToolExecutionResult struct {
+ Result string
+ ExecutionID string
+}
+
+// executeToolViaMCP 通过MCP执行工具
+func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) {
+ a.logger.Info("通过MCP执行工具",
+ zap.String("tool", toolName),
+ zap.Any("args", args),
+ )
+
+ // 通过MCP服务器调用工具
+ result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args)
+ if err != nil {
+ return nil, fmt.Errorf("工具执行失败: %w", err)
+ }
+
+ // 格式化结果
+ var resultText strings.Builder
+ for _, content := range result.Content {
+ resultText.WriteString(content.Text)
+ resultText.WriteString("\n")
+ }
+
+ return &ToolExecutionResult{
+ Result: resultText.String(),
+ ExecutionID: executionID,
+ }, nil
+}
+
diff --git a/internal/app/app.go b/internal/app/app.go
new file mode 100644
index 00000000..8f78d2b7
--- /dev/null
+++ b/internal/app/app.go
@@ -0,0 +1,163 @@
+package app
+
+import (
+ "fmt"
+ "net/http"
+ "os"
+ "path/filepath"
+
+ "cyberstrike-ai/internal/agent"
+ "cyberstrike-ai/internal/config"
+ "cyberstrike-ai/internal/database"
+ "cyberstrike-ai/internal/handler"
+ "cyberstrike-ai/internal/logger"
+ "cyberstrike-ai/internal/mcp"
+ "cyberstrike-ai/internal/security"
+
+ "github.com/gin-gonic/gin"
+ "go.uber.org/zap"
+)
+
+// App 应用
+type App struct {
+ config *config.Config
+ logger *logger.Logger
+ router *gin.Engine
+ mcpServer *mcp.Server
+ agent *agent.Agent
+ executor *security.Executor
+ db *database.DB
+}
+
+// New 创建新应用
+func New(cfg *config.Config, log *logger.Logger) (*App, error) {
+ gin.SetMode(gin.ReleaseMode)
+ router := gin.Default()
+
+ // CORS中间件
+ router.Use(corsMiddleware())
+
+ // 初始化数据库
+ dbPath := cfg.Database.Path
+ if dbPath == "" {
+ dbPath = "data/conversations.db"
+ }
+
+ // 确保目录存在
+ if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
+ return nil, fmt.Errorf("创建数据库目录失败: %w", err)
+ }
+
+ db, err := database.NewDB(dbPath, log.Logger)
+ if err != nil {
+ return nil, fmt.Errorf("初始化数据库失败: %w", err)
+ }
+
+ // 创建MCP服务器
+ mcpServer := mcp.NewServer(log.Logger)
+
+ // 创建安全工具执行器
+ executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
+
+ // 注册工具
+ executor.RegisterTools(mcpServer)
+
+ // 创建Agent
+ agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger)
+
+ // 创建处理器
+ agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
+ monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
+ conversationHandler := handler.NewConversationHandler(db, log.Logger)
+
+ // 设置路由
+ setupRoutes(router, agentHandler, monitorHandler, conversationHandler, mcpServer)
+
+ return &App{
+ config: cfg,
+ logger: log,
+ router: router,
+ mcpServer: mcpServer,
+ agent: agent,
+ executor: executor,
+ db: db,
+ }, nil
+}
+
+// Run 启动应用
+func (a *App) Run() error {
+ // 启动MCP服务器(如果启用)
+ if a.config.MCP.Enabled {
+ go func() {
+ mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
+ a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP)
+
+ if err := http.ListenAndServe(mcpAddr, mux); err != nil {
+ a.logger.Error("MCP服务器启动失败", zap.Error(err))
+ }
+ }()
+ }
+
+ // 启动主服务器
+ addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
+ a.logger.Info("启动HTTP服务器", zap.String("address", addr))
+
+ return a.router.Run(addr)
+}
+
+// setupRoutes 设置路由
+func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, mcpServer *mcp.Server) {
+ // API路由
+ api := router.Group("/api")
+ {
+ // Agent Loop
+ api.POST("/agent-loop", agentHandler.AgentLoop)
+
+ // 对话历史
+ api.POST("/conversations", conversationHandler.CreateConversation)
+ api.GET("/conversations", conversationHandler.ListConversations)
+ api.GET("/conversations/:id", conversationHandler.GetConversation)
+ api.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
+
+ // 监控
+ api.GET("/monitor", monitorHandler.Monitor)
+ api.GET("/monitor/execution/:id", monitorHandler.GetExecution)
+ api.GET("/monitor/stats", monitorHandler.GetStats)
+ api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
+
+ // MCP端点
+ api.POST("/mcp", func(c *gin.Context) {
+ mcpServer.HandleHTTP(c.Writer, c.Request)
+ })
+ }
+
+ // 静态文件
+ router.Static("/static", "./web/static")
+ router.LoadHTMLGlob("web/templates/*")
+
+ // 前端页面
+ router.GET("/", func(c *gin.Context) {
+ c.HTML(http.StatusOK, "index.html", nil)
+ })
+}
+
+// corsMiddleware CORS中间件
+func corsMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
+ c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
+ c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
+ c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
+
+ if c.Request.Method == "OPTIONS" {
+ c.AbortWithStatus(204)
+ return
+ }
+
+ c.Next()
+ }
+}
+
diff --git a/internal/config/config.go b/internal/config/config.go
new file mode 100644
index 00000000..2215bce7
--- /dev/null
+++ b/internal/config/config.go
@@ -0,0 +1,114 @@
+package config
+
+import (
+ "fmt"
+ "os"
+
+ "gopkg.in/yaml.v3"
+)
+
+type Config struct {
+ Server ServerConfig `yaml:"server"`
+ Log LogConfig `yaml:"log"`
+ MCP MCPConfig `yaml:"mcp"`
+ OpenAI OpenAIConfig `yaml:"openai"`
+ Security SecurityConfig `yaml:"security"`
+ Database DatabaseConfig `yaml:"database"`
+}
+
+type ServerConfig struct {
+ Host string `yaml:"host"`
+ Port int `yaml:"port"`
+}
+
+type LogConfig struct {
+ Level string `yaml:"level"`
+ Output string `yaml:"output"`
+}
+
+type MCPConfig struct {
+ Enabled bool `yaml:"enabled"`
+ Host string `yaml:"host"`
+ Port int `yaml:"port"`
+}
+
+type OpenAIConfig struct {
+ APIKey string `yaml:"api_key"`
+ BaseURL string `yaml:"base_url"`
+ Model string `yaml:"model"`
+}
+
+type SecurityConfig struct {
+ Tools []ToolConfig `yaml:"tools"`
+}
+
+type DatabaseConfig struct {
+ Path string `yaml:"path"`
+}
+
+type ToolConfig struct {
+ Name string `yaml:"name"`
+ Command string `yaml:"command"`
+ Args []string `yaml:"args,omitempty"` // 固定参数(可选)
+ Description string `yaml:"description"`
+ Enabled bool `yaml:"enabled"`
+ Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
+ ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
+}
+
+// ParameterConfig 参数配置
+type ParameterConfig struct {
+ Name string `yaml:"name"` // 参数名称
+ Type string `yaml:"type"` // 参数类型: string, int, bool, array
+ Description string `yaml:"description"` // 参数描述
+ Required bool `yaml:"required,omitempty"` // 是否必需
+ Default interface{} `yaml:"default,omitempty"` // 默认值
+ Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
+ Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
+ Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
+ Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
+ Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
+}
+
+func Load(path string) (*Config, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, fmt.Errorf("读取配置文件失败: %w", err)
+ }
+
+ var cfg Config
+ if err := yaml.Unmarshal(data, &cfg); err != nil {
+ return nil, fmt.Errorf("解析配置文件失败: %w", err)
+ }
+
+ return &cfg, nil
+}
+
+func Default() *Config {
+ return &Config{
+ Server: ServerConfig{
+ Host: "0.0.0.0",
+ Port: 8080,
+ },
+ Log: LogConfig{
+ Level: "info",
+ Output: "stdout",
+ },
+ MCP: MCPConfig{
+ Enabled: true,
+ Host: "0.0.0.0",
+ Port: 8081,
+ },
+ OpenAI: OpenAIConfig{
+ BaseURL: "https://api.openai.com/v1",
+ Model: "gpt-4",
+ },
+ Security: SecurityConfig{
+ Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 加载,不在此硬编码
+ },
+ Database: DatabaseConfig{
+ Path: "data/conversations.db",
+ },
+ }
+}
+
diff --git a/internal/database/conversation.go b/internal/database/conversation.go
new file mode 100644
index 00000000..95905772
--- /dev/null
+++ b/internal/database/conversation.go
@@ -0,0 +1,256 @@
+package database
+
+import (
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/google/uuid"
+ "go.uber.org/zap"
+)
+
+// Conversation 对话
+type Conversation struct {
+ ID string `json:"id"`
+ Title string `json:"title"`
+ CreatedAt time.Time `json:"createdAt"`
+ UpdatedAt time.Time `json:"updatedAt"`
+ Messages []Message `json:"messages,omitempty"`
+}
+
+// Message 消息
+type Message struct {
+ ID string `json:"id"`
+ ConversationID string `json:"conversationId"`
+ Role string `json:"role"`
+ Content string `json:"content"`
+ MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
+ CreatedAt time.Time `json:"createdAt"`
+}
+
+// CreateConversation 创建新对话
+func (db *DB) CreateConversation(title string) (*Conversation, error) {
+ id := uuid.New().String()
+ now := time.Now()
+
+ _, err := db.Exec(
+ "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
+ id, title, now, now,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("创建对话失败: %w", err)
+ }
+
+ return &Conversation{
+ ID: id,
+ Title: title,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }, nil
+}
+
+// GetConversation 获取对话
+func (db *DB) GetConversation(id string) (*Conversation, error) {
+ var conv Conversation
+ var createdAt, updatedAt string
+
+ err := db.QueryRow(
+ "SELECT id, title, created_at, updated_at FROM conversations WHERE id = ?",
+ id,
+ ).Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, fmt.Errorf("对话不存在")
+ }
+ return nil, fmt.Errorf("查询对话失败: %w", err)
+ }
+
+ // 尝试多种时间格式解析
+ var err1, err2 error
+ conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
+ if err1 != nil {
+ conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
+ }
+ if err1 != nil {
+ conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
+ }
+
+ conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
+ if err2 != nil {
+ conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
+ }
+ if err2 != nil {
+ conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
+ }
+
+ // 加载消息
+ messages, err := db.GetMessages(id)
+ if err != nil {
+ return nil, fmt.Errorf("加载消息失败: %w", err)
+ }
+ conv.Messages = messages
+
+ return &conv, nil
+}
+
+// ListConversations 列出所有对话
+func (db *DB) ListConversations(limit, offset int) ([]*Conversation, error) {
+ rows, err := db.Query(
+ "SELECT id, title, created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
+ limit, offset,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("查询对话列表失败: %w", err)
+ }
+ defer rows.Close()
+
+ var conversations []*Conversation
+ for rows.Next() {
+ var conv Conversation
+ var createdAt, updatedAt string
+
+ if err := rows.Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt); err != nil {
+ return nil, fmt.Errorf("扫描对话失败: %w", err)
+ }
+
+ // 尝试多种时间格式解析
+ var err1, err2 error
+ conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
+ if err1 != nil {
+ conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
+ }
+ if err1 != nil {
+ conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
+ }
+
+ conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
+ if err2 != nil {
+ conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
+ }
+ if err2 != nil {
+ conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
+ }
+
+ conversations = append(conversations, &conv)
+ }
+
+ return conversations, nil
+}
+
+// UpdateConversationTitle 更新对话标题
+func (db *DB) UpdateConversationTitle(id, title string) error {
+ _, err := db.Exec(
+ "UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
+ title, time.Now(), id,
+ )
+ if err != nil {
+ return fmt.Errorf("更新对话标题失败: %w", err)
+ }
+ return nil
+}
+
+// UpdateConversationTime 更新对话时间
+func (db *DB) UpdateConversationTime(id string) error {
+ _, err := db.Exec(
+ "UPDATE conversations SET updated_at = ? WHERE id = ?",
+ time.Now(), id,
+ )
+ if err != nil {
+ return fmt.Errorf("更新对话时间失败: %w", err)
+ }
+ return nil
+}
+
+// DeleteConversation 删除对话
+func (db *DB) DeleteConversation(id string) error {
+ _, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
+ if err != nil {
+ return fmt.Errorf("删除对话失败: %w", err)
+ }
+ return nil
+}
+
+// AddMessage 添加消息
+func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
+ id := uuid.New().String()
+
+ var mcpIDsJSON string
+ if len(mcpExecutionIDs) > 0 {
+ jsonData, err := json.Marshal(mcpExecutionIDs)
+ if err != nil {
+ db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
+ } else {
+ mcpIDsJSON = string(jsonData)
+ }
+ }
+
+ _, err := db.Exec(
+ "INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
+ id, conversationID, role, content, mcpIDsJSON, time.Now(),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("添加消息失败: %w", err)
+ }
+
+ // 更新对话时间
+ if err := db.UpdateConversationTime(conversationID); err != nil {
+ db.logger.Warn("更新对话时间失败", zap.Error(err))
+ }
+
+ message := &Message{
+ ID: id,
+ ConversationID: conversationID,
+ Role: role,
+ Content: content,
+ MCPExecutionIDs: mcpExecutionIDs,
+ CreatedAt: time.Now(),
+ }
+
+ return message, nil
+}
+
+// GetMessages 获取对话的所有消息
+func (db *DB) GetMessages(conversationID string) ([]Message, error) {
+ rows, err := db.Query(
+ "SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
+ conversationID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("查询消息失败: %w", err)
+ }
+ defer rows.Close()
+
+ var messages []Message
+ for rows.Next() {
+ var msg Message
+ var mcpIDsJSON sql.NullString
+ var createdAt string
+
+ if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
+ return nil, fmt.Errorf("扫描消息失败: %w", err)
+ }
+
+ // 尝试多种时间格式解析
+ var err error
+ msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
+ if err != nil {
+ msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
+ }
+ if err != nil {
+ msg.CreatedAt, err = time.Parse(time.RFC3339, createdAt)
+ }
+
+ // 解析MCP执行ID
+ if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
+ if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
+ db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
+ }
+ }
+
+ messages = append(messages, msg)
+ }
+
+ return messages, nil
+}
+
diff --git a/internal/database/database.go b/internal/database/database.go
new file mode 100644
index 00000000..5091e5ef
--- /dev/null
+++ b/internal/database/database.go
@@ -0,0 +1,90 @@
+package database
+
+import (
+ "database/sql"
+ "fmt"
+
+ _ "github.com/mattn/go-sqlite3"
+ "go.uber.org/zap"
+)
+
+// DB 数据库连接
+type DB struct {
+ *sql.DB
+ logger *zap.Logger
+}
+
+// NewDB 创建数据库连接
+func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
+ db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
+ if err != nil {
+ return nil, fmt.Errorf("打开数据库失败: %w", err)
+ }
+
+ if err := db.Ping(); err != nil {
+ return nil, fmt.Errorf("连接数据库失败: %w", err)
+ }
+
+ database := &DB{
+ DB: db,
+ logger: logger,
+ }
+
+ // 初始化表
+ if err := database.initTables(); err != nil {
+ return nil, fmt.Errorf("初始化表失败: %w", err)
+ }
+
+ return database, nil
+}
+
+// initTables 初始化数据库表
+func (db *DB) initTables() error {
+ // 创建对话表
+ createConversationsTable := `
+ CREATE TABLE IF NOT EXISTS conversations (
+ id TEXT PRIMARY KEY,
+ title TEXT NOT NULL,
+ created_at DATETIME NOT NULL,
+ updated_at DATETIME NOT NULL
+ );`
+
+ // 创建消息表
+ createMessagesTable := `
+ CREATE TABLE IF NOT EXISTS messages (
+ id TEXT PRIMARY KEY,
+ conversation_id TEXT NOT NULL,
+ role TEXT NOT NULL,
+ content TEXT NOT NULL,
+ mcp_execution_ids TEXT,
+ created_at DATETIME NOT NULL,
+ FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
+ );`
+
+ // 创建索引
+ createIndexes := `
+ CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
+ CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at);
+ `
+
+ if _, err := db.Exec(createConversationsTable); err != nil {
+ return fmt.Errorf("创建conversations表失败: %w", err)
+ }
+
+ if _, err := db.Exec(createMessagesTable); err != nil {
+ return fmt.Errorf("创建messages表失败: %w", err)
+ }
+
+ if _, err := db.Exec(createIndexes); err != nil {
+ return fmt.Errorf("创建索引失败: %w", err)
+ }
+
+ db.logger.Info("数据库表初始化完成")
+ return nil
+}
+
+// Close 关闭数据库连接
+func (db *DB) Close() error {
+ return db.DB.Close()
+}
+
diff --git a/internal/handler/agent.go b/internal/handler/agent.go
new file mode 100644
index 00000000..b628b08c
--- /dev/null
+++ b/internal/handler/agent.go
@@ -0,0 +1,134 @@
+package handler
+
+import (
+ "net/http"
+ "time"
+
+ "cyberstrike-ai/internal/agent"
+ "cyberstrike-ai/internal/database"
+ "github.com/gin-gonic/gin"
+ "go.uber.org/zap"
+)
+
+// AgentHandler Agent处理器
+type AgentHandler struct {
+ agent *agent.Agent
+ db *database.DB
+ logger *zap.Logger
+}
+
+// NewAgentHandler 创建新的Agent处理器
+func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
+ return &AgentHandler{
+ agent: agent,
+ db: db,
+ logger: logger,
+ }
+}
+
+// ChatRequest 聊天请求
+type ChatRequest struct {
+ Message string `json:"message" binding:"required"`
+ ConversationID string `json:"conversationId,omitempty"`
+}
+
+// ChatResponse 聊天响应
+type ChatResponse struct {
+ Response string `json:"response"`
+ MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表
+ ConversationID string `json:"conversationId"` // 对话ID
+ Time time.Time `json:"time"`
+}
+
+// AgentLoop 处理Agent Loop请求
+func (h *AgentHandler) AgentLoop(c *gin.Context) {
+ var req ChatRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+
+ h.logger.Info("收到Agent Loop请求",
+ zap.String("message", req.Message),
+ zap.String("conversationId", req.ConversationID),
+ )
+
+ // 如果没有对话ID,创建新对话
+ conversationID := req.ConversationID
+ if conversationID == "" {
+ title := req.Message
+ if len(title) > 50 {
+ title = title[:50] + "..."
+ }
+ conv, err := h.db.CreateConversation(title)
+ if err != nil {
+ h.logger.Error("创建对话失败", zap.Error(err))
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+ conversationID = conv.ID
+ }
+
+ // 获取历史消息(排除当前消息,因为还没保存)
+ historyMessages, err := h.db.GetMessages(conversationID)
+ if err != nil {
+ h.logger.Warn("获取历史消息失败", zap.Error(err))
+ historyMessages = []database.Message{}
+ }
+
+ h.logger.Info("获取历史消息",
+ zap.String("conversationId", conversationID),
+ zap.Int("count", len(historyMessages)),
+ )
+
+ // 将数据库消息转换为Agent消息格式
+ agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
+ for i, msg := range historyMessages {
+ agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
+ Role: msg.Role,
+ Content: msg.Content,
+ })
+ contentPreview := msg.Content
+ if len(contentPreview) > 50 {
+ contentPreview = contentPreview[:50] + "..."
+ }
+ h.logger.Info("添加历史消息",
+ zap.Int("index", i),
+ zap.String("role", msg.Role),
+ zap.String("content", contentPreview),
+ )
+ }
+
+ h.logger.Info("历史消息转换完成",
+ zap.Int("originalCount", len(historyMessages)),
+ zap.Int("convertedCount", len(agentHistoryMessages)),
+ )
+
+ // 保存用户消息
+ _, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
+ if err != nil {
+ h.logger.Error("保存用户消息失败", zap.Error(err))
+ }
+
+ // 执行Agent Loop,传入历史消息
+ result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
+ if err != nil {
+ h.logger.Error("Agent Loop执行失败", zap.Error(err))
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ // 保存助手回复
+ _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
+ if err != nil {
+ h.logger.Error("保存助手消息失败", zap.Error(err))
+ }
+
+ c.JSON(http.StatusOK, ChatResponse{
+ Response: result.Response,
+ MCPExecutionIDs: result.MCPExecutionIDs,
+ ConversationID: conversationID,
+ Time: time.Now(),
+ })
+}
+
diff --git a/internal/handler/conversation.go b/internal/handler/conversation.go
new file mode 100644
index 00000000..7e2a1d22
--- /dev/null
+++ b/internal/handler/conversation.go
@@ -0,0 +1,102 @@
+package handler
+
+import (
+ "net/http"
+ "strconv"
+
+ "cyberstrike-ai/internal/database"
+ "github.com/gin-gonic/gin"
+ "go.uber.org/zap"
+)
+
+// ConversationHandler 对话处理器
+type ConversationHandler struct {
+ db *database.DB
+ logger *zap.Logger
+}
+
+// NewConversationHandler 创建新的对话处理器
+func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
+ return &ConversationHandler{
+ db: db,
+ logger: logger,
+ }
+}
+
+// CreateConversationRequest 创建对话请求
+type CreateConversationRequest struct {
+ Title string `json:"title"`
+}
+
+// CreateConversation 创建新对话
+func (h *ConversationHandler) CreateConversation(c *gin.Context) {
+ var req CreateConversationRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+
+ title := req.Title
+ if title == "" {
+ title = "新对话"
+ }
+
+ conv, err := h.db.CreateConversation(title)
+ if err != nil {
+ h.logger.Error("创建对话失败", zap.Error(err))
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ c.JSON(http.StatusOK, conv)
+}
+
+// ListConversations 列出对话
+func (h *ConversationHandler) ListConversations(c *gin.Context) {
+ limitStr := c.DefaultQuery("limit", "50")
+ offsetStr := c.DefaultQuery("offset", "0")
+
+ limit, _ := strconv.Atoi(limitStr)
+ offset, _ := strconv.Atoi(offsetStr)
+
+ if limit <= 0 || limit > 100 {
+ limit = 50
+ }
+
+ conversations, err := h.db.ListConversations(limit, offset)
+ if err != nil {
+ h.logger.Error("获取对话列表失败", zap.Error(err))
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ c.JSON(http.StatusOK, conversations)
+}
+
+// GetConversation 获取对话
+func (h *ConversationHandler) GetConversation(c *gin.Context) {
+ id := c.Param("id")
+
+ conv, err := h.db.GetConversation(id)
+ if err != nil {
+ h.logger.Error("获取对话失败", zap.Error(err))
+ c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
+ return
+ }
+
+ c.JSON(http.StatusOK, conv)
+}
+
+// DeleteConversation 删除对话
+func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
+ id := c.Param("id")
+
+ if err := h.db.DeleteConversation(id); err != nil {
+ h.logger.Error("删除对话失败", zap.Error(err))
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
+}
+
diff --git a/internal/handler/monitor.go b/internal/handler/monitor.go
new file mode 100644
index 00000000..a5893a1e
--- /dev/null
+++ b/internal/handler/monitor.go
@@ -0,0 +1,92 @@
+package handler
+
+import (
+ "net/http"
+ "time"
+
+ "cyberstrike-ai/internal/mcp"
+ "cyberstrike-ai/internal/security"
+ "github.com/gin-gonic/gin"
+ "go.uber.org/zap"
+)
+
+// MonitorHandler 监控处理器
+type MonitorHandler struct {
+ mcpServer *mcp.Server
+ executor *security.Executor
+ logger *zap.Logger
+ vulns []security.Vulnerability
+}
+
+// NewMonitorHandler 创建新的监控处理器
+func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, logger *zap.Logger) *MonitorHandler {
+ return &MonitorHandler{
+ mcpServer: mcpServer,
+ executor: executor,
+ logger: logger,
+ vulns: []security.Vulnerability{},
+ }
+}
+
+// MonitorResponse 监控响应
+type MonitorResponse struct {
+ Executions []*mcp.ToolExecution `json:"executions"`
+ Stats map[string]*mcp.ToolStats `json:"stats"`
+ Vulnerabilities []security.Vulnerability `json:"vulnerabilities"`
+ Report map[string]interface{} `json:"report"`
+ Timestamp time.Time `json:"timestamp"`
+}
+
+// Monitor 获取监控信息
+func (h *MonitorHandler) Monitor(c *gin.Context) {
+ // 获取所有执行记录
+ executions := h.mcpServer.GetAllExecutions()
+
+ // 分析执行结果,提取漏洞
+ for _, exec := range executions {
+ if exec.Status == "completed" && exec.Result != nil {
+ vulns := h.executor.AnalyzeResults(exec.ToolName, exec.Result)
+ h.vulns = append(h.vulns, vulns...)
+ }
+ }
+
+ // 获取统计信息
+ stats := h.mcpServer.GetStats()
+
+ // 生成报告
+ report := h.executor.GetVulnerabilityReport(h.vulns)
+
+ c.JSON(http.StatusOK, MonitorResponse{
+ Executions: executions,
+ Stats: stats,
+ Vulnerabilities: h.vulns,
+ Report: report,
+ Timestamp: time.Now(),
+ })
+}
+
+// GetExecution 获取特定执行记录
+func (h *MonitorHandler) GetExecution(c *gin.Context) {
+ id := c.Param("id")
+
+ exec, exists := h.mcpServer.GetExecution(id)
+ if !exists {
+ c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
+ return
+ }
+
+ c.JSON(http.StatusOK, exec)
+}
+
+// GetStats 获取统计信息
+func (h *MonitorHandler) GetStats(c *gin.Context) {
+ stats := h.mcpServer.GetStats()
+ c.JSON(http.StatusOK, stats)
+}
+
+// GetVulnerabilities 获取漏洞列表
+func (h *MonitorHandler) GetVulnerabilities(c *gin.Context) {
+ report := h.executor.GetVulnerabilityReport(h.vulns)
+ c.JSON(http.StatusOK, report)
+}
+
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
new file mode 100644
index 00000000..549bceb0
--- /dev/null
+++ b/internal/logger/logger.go
@@ -0,0 +1,60 @@
+package logger
+
+import (
+ "os"
+
+ "go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
+)
+
+type Logger struct {
+ *zap.Logger
+}
+
+func New(level, output string) *Logger {
+ var zapLevel zapcore.Level
+ switch level {
+ case "debug":
+ zapLevel = zapcore.DebugLevel
+ case "info":
+ zapLevel = zapcore.InfoLevel
+ case "warn":
+ zapLevel = zapcore.WarnLevel
+ case "error":
+ zapLevel = zapcore.ErrorLevel
+ default:
+ zapLevel = zapcore.InfoLevel
+ }
+
+ config := zap.NewProductionConfig()
+ config.Level = zap.NewAtomicLevelAt(zapLevel)
+ config.EncoderConfig.TimeKey = "timestamp"
+ config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
+
+ var writeSyncer zapcore.WriteSyncer
+ if output == "stdout" {
+ writeSyncer = zapcore.AddSync(os.Stdout)
+ } else {
+ file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
+ if err != nil {
+ writeSyncer = zapcore.AddSync(os.Stdout)
+ } else {
+ writeSyncer = zapcore.AddSync(file)
+ }
+ }
+
+ core := zapcore.NewCore(
+ zapcore.NewJSONEncoder(config.EncoderConfig),
+ writeSyncer,
+ zapLevel,
+ )
+
+ logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
+
+ return &Logger{Logger: logger}
+}
+
+func (l *Logger) Fatal(msg string, fields ...interface{}) {
+ l.Logger.Fatal(msg, zap.Any("fields", fields))
+}
+
diff --git a/internal/mcp/server.go b/internal/mcp/server.go
new file mode 100644
index 00000000..c9e53ea5
--- /dev/null
+++ b/internal/mcp/server.go
@@ -0,0 +1,798 @@
+package mcp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+ "go.uber.org/zap"
+)
+
+// Server MCP服务器
+type Server struct {
+ tools map[string]ToolHandler
+ toolDefs map[string]Tool // 工具定义
+ executions map[string]*ToolExecution
+ stats map[string]*ToolStats
+ prompts map[string]*Prompt // 提示词模板
+ resources map[string]*Resource // 资源
+ mu sync.RWMutex
+ logger *zap.Logger
+}
+
+// ToolHandler 工具处理函数
+type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
+
+// NewServer 创建新的MCP服务器
+func NewServer(logger *zap.Logger) *Server {
+ s := &Server{
+ tools: make(map[string]ToolHandler),
+ toolDefs: make(map[string]Tool),
+ executions: make(map[string]*ToolExecution),
+ stats: make(map[string]*ToolStats),
+ prompts: make(map[string]*Prompt),
+ resources: make(map[string]*Resource),
+ logger: logger,
+ }
+
+ // 初始化默认提示词和资源
+ s.initDefaultPrompts()
+ s.initDefaultResources()
+
+ return s
+}
+
+// RegisterTool 注册工具
+func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.tools[tool.Name] = handler
+ s.toolDefs[tool.Name] = tool
+
+ // 自动为工具创建资源文档
+ resourceURI := fmt.Sprintf("tool://%s", tool.Name)
+ s.resources[resourceURI] = &Resource{
+ URI: resourceURI,
+ Name: fmt.Sprintf("%s工具文档", tool.Name),
+ Description: tool.Description,
+ MimeType: "text/plain",
+ }
+}
+
+// HandleHTTP 处理HTTP请求
+func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ s.sendError(w, nil, -32700, "Parse error", err.Error())
+ return
+ }
+
+ var msg Message
+ if err := json.Unmarshal(body, &msg); err != nil {
+ s.sendError(w, nil, -32700, "Parse error", err.Error())
+ return
+ }
+
+ // 处理消息
+ response := s.handleMessage(&msg)
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(response)
+}
+
+// handleMessage 处理MCP消息
+func (s *Server) handleMessage(msg *Message) *Message {
+ if msg.ID == "" {
+ msg.ID = uuid.New().String()
+ }
+
+ switch msg.Method {
+ case "initialize":
+ return s.handleInitialize(msg)
+ case "tools/list":
+ return s.handleListTools(msg)
+ case "tools/call":
+ return s.handleCallTool(msg)
+ case "prompts/list":
+ return s.handleListPrompts(msg)
+ case "prompts/get":
+ return s.handleGetPrompt(msg)
+ case "resources/list":
+ return s.handleListResources(msg)
+ case "resources/read":
+ return s.handleReadResource(msg)
+ case "sampling/request":
+ return s.handleSamplingRequest(msg)
+ default:
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32601, Message: "Method not found"},
+ }
+ }
+}
+
+// handleInitialize 处理初始化请求
+func (s *Server) handleInitialize(msg *Message) *Message {
+ var req InitializeRequest
+ if err := json.Unmarshal(msg.Params, &req); err != nil {
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32602, Message: "Invalid params"},
+ }
+ }
+
+ response := InitializeResponse{
+ ProtocolVersion: ProtocolVersion,
+ Capabilities: ServerCapabilities{
+ Tools: map[string]interface{}{
+ "listChanged": true,
+ },
+ Prompts: map[string]interface{}{
+ "listChanged": true,
+ },
+ Resources: map[string]interface{}{
+ "subscribe": true,
+ "listChanged": true,
+ },
+ Sampling: map[string]interface{}{},
+ },
+ ServerInfo: ServerInfo{
+ Name: "CyberStrikeAI",
+ Version: "1.0.0",
+ },
+ }
+
+ result, _ := json.Marshal(response)
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: result,
+ }
+}
+
+// handleListTools 处理列出工具请求
+func (s *Server) handleListTools(msg *Message) *Message {
+ s.mu.RLock()
+ tools := make([]Tool, 0, len(s.toolDefs))
+ for _, tool := range s.toolDefs {
+ tools = append(tools, tool)
+ }
+ s.mu.RUnlock()
+
+ response := ListToolsResponse{Tools: tools}
+ result, _ := json.Marshal(response)
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: result,
+ }
+}
+
+// handleCallTool 处理工具调用请求
+func (s *Server) handleCallTool(msg *Message) *Message {
+ var req CallToolRequest
+ if err := json.Unmarshal(msg.Params, &req); err != nil {
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32602, Message: "Invalid params"},
+ }
+ }
+
+ // 创建执行记录
+ executionID := uuid.New().String()
+ execution := &ToolExecution{
+ ID: executionID,
+ ToolName: req.Name,
+ Arguments: req.Arguments,
+ Status: "running",
+ StartTime: time.Now(),
+ }
+
+ s.mu.Lock()
+ s.executions[executionID] = execution
+ s.mu.Unlock()
+
+ // 更新统计
+ s.updateStats(req.Name, false)
+
+ // 执行工具
+ s.mu.RLock()
+ handler, exists := s.tools[req.Name]
+ s.mu.RUnlock()
+
+ if !exists {
+ execution.Status = "failed"
+ execution.Error = "Tool not found"
+ now := time.Now()
+ execution.EndTime = &now
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32601, Message: "Tool not found"},
+ }
+ }
+
+ // 同步执行所有工具,确保错误能正确返回
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+ defer cancel()
+
+ s.logger.Info("开始执行工具",
+ zap.String("toolName", req.Name),
+ zap.Any("arguments", req.Arguments),
+ )
+
+ result, err := handler(ctx, req.Arguments)
+
+ s.mu.Lock()
+ now := time.Now()
+ execution.EndTime = &now
+ execution.Duration = now.Sub(execution.StartTime)
+
+ if err != nil {
+ execution.Status = "failed"
+ execution.Error = err.Error()
+ s.updateStats(req.Name, true)
+ s.mu.Unlock()
+
+ s.logger.Error("工具执行失败",
+ zap.String("toolName", req.Name),
+ zap.Error(err),
+ )
+
+ // 返回错误结果
+ errorResult, _ := json.Marshal(CallToolResponse{
+ Content: []Content{
+ {Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)},
+ },
+ IsError: true,
+ })
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: errorResult,
+ }
+ }
+
+ // 检查result是否为错误
+ if result != nil && result.IsError {
+ execution.Status = "failed"
+ if len(result.Content) > 0 {
+ execution.Error = result.Content[0].Text
+ }
+ s.updateStats(req.Name, true)
+ } else {
+ execution.Status = "completed"
+ execution.Result = result
+ s.updateStats(req.Name, false)
+ }
+ s.mu.Unlock()
+
+ // 返回执行结果
+ if result == nil {
+ result = &ToolResult{
+ Content: []Content{
+ {Type: "text", Text: "工具执行完成,但未返回结果"},
+ },
+ }
+ }
+
+ resultJSON, _ := json.Marshal(CallToolResponse{
+ Content: result.Content,
+ IsError: result.IsError,
+ })
+
+ s.logger.Info("工具执行完成",
+ zap.String("toolName", req.Name),
+ zap.Bool("isError", result.IsError),
+ )
+
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: resultJSON,
+ }
+}
+
+// updateStats 更新统计信息
+func (s *Server) updateStats(toolName string, failed bool) {
+ if s.stats[toolName] == nil {
+ s.stats[toolName] = &ToolStats{
+ ToolName: toolName,
+ }
+ }
+
+ stats := s.stats[toolName]
+ stats.TotalCalls++
+ now := time.Now()
+ stats.LastCallTime = &now
+
+ if failed {
+ stats.FailedCalls++
+ } else {
+ stats.SuccessCalls++
+ }
+}
+
+// GetExecution 获取执行记录
+func (s *Server) GetExecution(id string) (*ToolExecution, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ exec, exists := s.executions[id]
+ return exec, exists
+}
+
+// GetAllExecutions 获取所有执行记录
+func (s *Server) GetAllExecutions() []*ToolExecution {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ executions := make([]*ToolExecution, 0, len(s.executions))
+ for _, exec := range s.executions {
+ executions = append(executions, exec)
+ }
+ return executions
+}
+
+// GetStats 获取统计信息
+func (s *Server) GetStats() map[string]*ToolStats {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ stats := make(map[string]*ToolStats)
+ for k, v := range s.stats {
+ stats[k] = v
+ }
+ return stats
+}
+
+// CallTool 直接调用工具(用于内部调用)
+func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
+ s.mu.RLock()
+ handler, exists := s.tools[toolName]
+ s.mu.RUnlock()
+
+ if !exists {
+ return nil, "", fmt.Errorf("工具 %s 未找到", toolName)
+ }
+
+ // 创建执行记录
+ executionID := uuid.New().String()
+ execution := &ToolExecution{
+ ID: executionID,
+ ToolName: toolName,
+ Arguments: args,
+ Status: "running",
+ StartTime: time.Now(),
+ }
+
+ s.mu.Lock()
+ s.executions[executionID] = execution
+ s.mu.Unlock()
+
+ // 更新统计
+ s.updateStats(toolName, false)
+
+ // 执行工具
+ result, err := handler(ctx, args)
+
+ s.mu.Lock()
+ now := time.Now()
+ execution.EndTime = &now
+ execution.Duration = now.Sub(execution.StartTime)
+
+ if err != nil {
+ execution.Status = "failed"
+ execution.Error = err.Error()
+ s.updateStats(toolName, true)
+ s.mu.Unlock()
+ return nil, executionID, err
+ } else {
+ execution.Status = "completed"
+ execution.Result = result
+ s.updateStats(toolName, false)
+ s.mu.Unlock()
+ return result, executionID, nil
+ }
+}
+
+// initDefaultPrompts 初始化默认提示词模板
+func (s *Server) initDefaultPrompts() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // 网络安全测试提示词
+ s.prompts["security_scan"] = &Prompt{
+ Name: "security_scan",
+ Description: "生成网络安全扫描任务的提示词",
+ Arguments: []PromptArgument{
+ {Name: "target", Description: "扫描目标(IP地址或域名)", Required: true},
+ {Name: "scan_type", Description: "扫描类型(port, vuln, web等)", Required: false},
+ },
+ }
+
+ // 渗透测试提示词
+ s.prompts["penetration_test"] = &Prompt{
+ Name: "penetration_test",
+ Description: "生成渗透测试任务的提示词",
+ Arguments: []PromptArgument{
+ {Name: "target", Description: "测试目标", Required: true},
+ {Name: "scope", Description: "测试范围", Required: false},
+ },
+ }
+}
+
+// initDefaultResources 初始化默认资源
+// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源
+func (s *Server) initDefaultResources() {
+ // 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码
+}
+
+// handleListPrompts 处理列出提示词请求
+func (s *Server) handleListPrompts(msg *Message) *Message {
+ s.mu.RLock()
+ prompts := make([]Prompt, 0, len(s.prompts))
+ for _, prompt := range s.prompts {
+ prompts = append(prompts, *prompt)
+ }
+ s.mu.RUnlock()
+
+ response := ListPromptsResponse{
+ Prompts: prompts,
+ }
+ result, _ := json.Marshal(response)
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: result,
+ }
+}
+
+// handleGetPrompt 处理获取提示词请求
+func (s *Server) handleGetPrompt(msg *Message) *Message {
+ var req GetPromptRequest
+ if err := json.Unmarshal(msg.Params, &req); err != nil {
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32602, Message: "Invalid params"},
+ }
+ }
+
+ s.mu.RLock()
+ prompt, exists := s.prompts[req.Name]
+ s.mu.RUnlock()
+
+ if !exists {
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32601, Message: "Prompt not found"},
+ }
+ }
+
+ // 根据提示词名称生成消息
+ messages := s.generatePromptMessages(prompt, req.Arguments)
+
+ response := GetPromptResponse{
+ Messages: messages,
+ }
+ result, _ := json.Marshal(response)
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: result,
+ }
+}
+
+// generatePromptMessages 生成提示词消息
+func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage {
+ messages := []PromptMessage{}
+
+ switch prompt.Name {
+ case "security_scan":
+ target, _ := args["target"].(string)
+ scanType, _ := args["scan_type"].(string)
+ if scanType == "" {
+ scanType = "comprehensive"
+ }
+
+ content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括:
+1. 端口扫描和服务识别
+2. 漏洞检测
+3. Web应用安全测试
+4. 生成详细的安全报告`, target, scanType)
+
+ messages = append(messages, PromptMessage{
+ Role: "user",
+ Content: content,
+ })
+
+ case "penetration_test":
+ target, _ := args["target"].(string)
+ scope, _ := args["scope"].(string)
+
+ content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target)
+ if scope != "" {
+ content += fmt.Sprintf("测试范围:%s", scope)
+ }
+ content += "\n请按照OWASP Top 10进行全面的安全测试。"
+
+ messages = append(messages, PromptMessage{
+ Role: "user",
+ Content: content,
+ })
+
+ default:
+ messages = append(messages, PromptMessage{
+ Role: "user",
+ Content: "请执行安全测试任务",
+ })
+ }
+
+ return messages
+}
+
+// handleListResources 处理列出资源请求
+func (s *Server) handleListResources(msg *Message) *Message {
+ s.mu.RLock()
+ resources := make([]Resource, 0, len(s.resources))
+ for _, resource := range s.resources {
+ resources = append(resources, *resource)
+ }
+ s.mu.RUnlock()
+
+ response := ListResourcesResponse{
+ Resources: resources,
+ }
+ result, _ := json.Marshal(response)
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: result,
+ }
+}
+
+// handleReadResource 处理读取资源请求
+func (s *Server) handleReadResource(msg *Message) *Message {
+ var req ReadResourceRequest
+ if err := json.Unmarshal(msg.Params, &req); err != nil {
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32602, Message: "Invalid params"},
+ }
+ }
+
+ s.mu.RLock()
+ resource, exists := s.resources[req.URI]
+ s.mu.RUnlock()
+
+ if !exists {
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32601, Message: "Resource not found"},
+ }
+ }
+
+ // 生成资源内容
+ content := s.generateResourceContent(resource)
+
+ response := ReadResourceResponse{
+ Contents: []ResourceContent{content},
+ }
+ result, _ := json.Marshal(response)
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: result,
+ }
+}
+
+// generateResourceContent 生成资源内容
+func (s *Server) generateResourceContent(resource *Resource) ResourceContent {
+ content := ResourceContent{
+ URI: resource.URI,
+ MimeType: resource.MimeType,
+ }
+
+ // 如果是工具资源,生成详细文档
+ if strings.HasPrefix(resource.URI, "tool://") {
+ toolName := strings.TrimPrefix(resource.URI, "tool://")
+ content.Text = s.generateToolDocumentation(toolName, resource)
+ } else {
+ // 其他资源使用描述或默认内容
+ content.Text = resource.Description
+ }
+
+ return content
+}
+
+// generateToolDocumentation 生成工具文档
+func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string {
+ // 获取工具定义以获取更详细的信息
+ s.mu.RLock()
+ tool, hasTool := s.toolDefs[toolName]
+ s.mu.RUnlock()
+
+ // 为常见工具生成详细文档
+ switch toolName {
+ case "nmap":
+ return `Nmap (Network Mapper) 是一个强大的网络扫描工具。
+
+主要功能:
+- 端口扫描:发现目标主机开放的端口
+- 服务识别:识别运行在端口上的服务
+- 版本检测:检测服务版本信息
+- 操作系统检测:识别目标操作系统
+
+常用命令:
+- nmap -sT target # TCP连接扫描
+- nmap -sV target # 版本检测
+- nmap -sC target # 默认脚本扫描
+- nmap -p 1-1000 target # 扫描指定端口范围
+
+参数说明:
+- target: 目标IP地址或域名(必需)
+- ports: 端口范围,例如: 1-1000(可选)`
+
+ case "sqlmap":
+ return `SQLMap 是一个自动化的SQL注入检测和利用工具。
+
+主要功能:
+- 自动检测SQL注入漏洞
+- 数据库指纹识别
+- 数据提取
+- 文件系统访问
+
+常用命令:
+- sqlmap -u "http://target.com/page?id=1" # 检测URL参数
+- sqlmap -u "http://target.com" --forms # 检测表单
+- sqlmap -u "http://target.com" --dbs # 列出数据库
+
+参数说明:
+- url: 目标URL(必需)`
+
+ case "nikto":
+ return `Nikto 是一个Web服务器扫描工具。
+
+主要功能:
+- Web服务器漏洞扫描
+- 检测过时的服务器软件
+- 检测危险文件和程序
+- 检测服务器配置问题
+
+常用命令:
+- nikto -h target # 扫描目标主机
+- nikto -h target -p 80,443 # 扫描指定端口
+
+参数说明:
+- target: 目标URL(必需)`
+
+ case "dirb":
+ return `Dirb 是一个Web内容扫描器。
+
+主要功能:
+- 扫描Web目录和文件
+- 发现隐藏的目录和文件
+- 支持自定义字典
+
+常用命令:
+- dirb url # 扫描目标URL
+- dirb url -w wordlist.txt # 使用自定义字典
+
+参数说明:
+- target: 目标URL(必需)`
+
+ case "exec":
+ return `Exec 工具用于执行系统命令。
+
+⚠️ 警告:此工具可以执行任意系统命令,请谨慎使用!
+
+参数说明:
+- command: 要执行的系统命令(必需)
+- shell: 使用的shell,默认为sh(可选)
+- workdir: 工作目录(可选)`
+
+ default:
+ // 对于其他工具,使用工具定义中的描述信息
+ if hasTool {
+ doc := fmt.Sprintf("%s\n\n", resource.Description)
+ if tool.InputSchema != nil {
+ if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok {
+ doc += "参数说明:\n"
+ for paramName, paramInfo := range props {
+ if paramMap, ok := paramInfo.(map[string]interface{}); ok {
+ if desc, ok := paramMap["description"].(string); ok {
+ doc += fmt.Sprintf("- %s: %s\n", paramName, desc)
+ }
+ }
+ }
+ }
+ }
+ return doc
+ }
+ return resource.Description
+ }
+}
+
+// handleSamplingRequest 处理采样请求
+func (s *Server) handleSamplingRequest(msg *Message) *Message {
+ var req SamplingRequest
+ if err := json.Unmarshal(msg.Params, &req); err != nil {
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeError,
+ Error: &Error{Code: -32602, Message: "Invalid params"},
+ }
+ }
+
+ // 注意:采样功能通常需要连接到实际的LLM服务
+ // 这里返回一个占位符响应,实际实现需要集成LLM API
+ s.logger.Warn("Sampling request received but not fully implemented",
+ zap.Any("request", req),
+ )
+
+ response := SamplingResponse{
+ Content: []SamplingContent{
+ {
+ Type: "text",
+ Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。",
+ },
+ },
+ StopReason: "length",
+ }
+ result, _ := json.Marshal(response)
+ return &Message{
+ ID: msg.ID,
+ Type: MessageTypeResponse,
+ Version: "2.0",
+ Result: result,
+ }
+}
+
+// RegisterPrompt 注册提示词模板
+func (s *Server) RegisterPrompt(prompt *Prompt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.prompts[prompt.Name] = prompt
+}
+
+// RegisterResource 注册资源
+func (s *Server) RegisterResource(resource *Resource) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.resources[resource.URI] = resource
+}
+
+// sendError 发送错误响应
+func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) {
+ response := Message{
+ ID: fmt.Sprintf("%v", id),
+ Type: MessageTypeError,
+ Error: &Error{Code: code, Message: message, Data: data},
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(response)
+}
+
diff --git a/internal/mcp/types.go b/internal/mcp/types.go
new file mode 100644
index 00000000..a757196f
--- /dev/null
+++ b/internal/mcp/types.go
@@ -0,0 +1,232 @@
+package mcp
+
+import (
+ "encoding/json"
+ "time"
+)
+
+// MCP消息类型
+const (
+ MessageTypeRequest = "request"
+ MessageTypeResponse = "response"
+ MessageTypeError = "error"
+ MessageTypeNotify = "notify"
+)
+
+// MCP协议版本
+const ProtocolVersion = "2024-11-05"
+
+// Message 表示MCP消息
+type Message struct {
+ ID string `json:"id,omitempty"`
+ Type string `json:"type"`
+ Method string `json:"method,omitempty"`
+ Params json.RawMessage `json:"params,omitempty"`
+ Result json.RawMessage `json:"result,omitempty"`
+ Error *Error `json:"error,omitempty"`
+ Version string `json:"jsonrpc,omitempty"`
+}
+
+// Error 表示MCP错误
+type Error struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data interface{} `json:"data,omitempty"`
+}
+
+// Tool 表示MCP工具定义
+type Tool struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema map[string]interface{} `json:"inputSchema"`
+}
+
+// ToolCall 表示工具调用
+type ToolCall struct {
+ Name string `json:"name"`
+ Arguments map[string]interface{} `json:"arguments"`
+}
+
+// ToolResult 表示工具执行结果
+type ToolResult struct {
+ Content []Content `json:"content"`
+ IsError bool `json:"isError,omitempty"`
+}
+
+// Content 表示内容
+type Content struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+}
+
+// InitializeRequest 初始化请求
+type InitializeRequest struct {
+ ProtocolVersion string `json:"protocolVersion"`
+ Capabilities map[string]interface{} `json:"capabilities"`
+ ClientInfo ClientInfo `json:"clientInfo"`
+}
+
+// ClientInfo 客户端信息
+type ClientInfo struct {
+ Name string `json:"name"`
+ Version string `json:"version"`
+}
+
+// InitializeResponse 初始化响应
+type InitializeResponse struct {
+ ProtocolVersion string `json:"protocolVersion"`
+ Capabilities ServerCapabilities `json:"capabilities"`
+ ServerInfo ServerInfo `json:"serverInfo"`
+}
+
+// ServerCapabilities 服务器能力
+type ServerCapabilities struct {
+ Tools map[string]interface{} `json:"tools,omitempty"`
+ Prompts map[string]interface{} `json:"prompts,omitempty"`
+ Resources map[string]interface{} `json:"resources,omitempty"`
+ Sampling map[string]interface{} `json:"sampling,omitempty"`
+}
+
+// ServerInfo 服务器信息
+type ServerInfo struct {
+ Name string `json:"name"`
+ Version string `json:"version"`
+}
+
+// ListToolsRequest 列出工具请求
+type ListToolsRequest struct{}
+
+// ListToolsResponse 列出工具响应
+type ListToolsResponse struct {
+ Tools []Tool `json:"tools"`
+}
+
+// ListPromptsResponse 列出提示词响应
+type ListPromptsResponse struct {
+ Prompts []Prompt `json:"prompts"`
+}
+
+// ListResourcesResponse 列出资源响应
+type ListResourcesResponse struct {
+ Resources []Resource `json:"resources"`
+}
+
+// CallToolRequest 调用工具请求
+type CallToolRequest struct {
+ Name string `json:"name"`
+ Arguments map[string]interface{} `json:"arguments"`
+}
+
+// CallToolResponse 调用工具响应
+type CallToolResponse struct {
+ Content []Content `json:"content"`
+ IsError bool `json:"isError,omitempty"`
+}
+
+// ToolExecution 工具执行记录
+type ToolExecution struct {
+ ID string `json:"id"`
+ ToolName string `json:"toolName"`
+ Arguments map[string]interface{} `json:"arguments"`
+ Status string `json:"status"` // pending, running, completed, failed
+ Result *ToolResult `json:"result,omitempty"`
+ Error string `json:"error,omitempty"`
+ StartTime time.Time `json:"startTime"`
+ EndTime *time.Time `json:"endTime,omitempty"`
+ Duration time.Duration `json:"duration,omitempty"`
+}
+
+// ToolStats 工具统计信息
+type ToolStats struct {
+ ToolName string `json:"toolName"`
+ TotalCalls int `json:"totalCalls"`
+ SuccessCalls int `json:"successCalls"`
+ FailedCalls int `json:"failedCalls"`
+ LastCallTime *time.Time `json:"lastCallTime,omitempty"`
+}
+
+// Prompt 提示词模板
+type Prompt struct {
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ Arguments []PromptArgument `json:"arguments,omitempty"`
+}
+
+// PromptArgument 提示词参数
+type PromptArgument struct {
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ Required bool `json:"required,omitempty"`
+}
+
+// GetPromptRequest 获取提示词请求
+type GetPromptRequest struct {
+ Name string `json:"name"`
+ Arguments map[string]interface{} `json:"arguments,omitempty"`
+}
+
+// GetPromptResponse 获取提示词响应
+type GetPromptResponse struct {
+ Messages []PromptMessage `json:"messages"`
+}
+
+// PromptMessage 提示词消息
+type PromptMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// Resource 资源
+type Resource struct {
+ URI string `json:"uri"`
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ MimeType string `json:"mimeType,omitempty"`
+}
+
+// ReadResourceRequest 读取资源请求
+type ReadResourceRequest struct {
+ URI string `json:"uri"`
+}
+
+// ReadResourceResponse 读取资源响应
+type ReadResourceResponse struct {
+ Contents []ResourceContent `json:"contents"`
+}
+
+// ResourceContent 资源内容
+type ResourceContent struct {
+ URI string `json:"uri"`
+ MimeType string `json:"mimeType,omitempty"`
+ Text string `json:"text,omitempty"`
+ Blob string `json:"blob,omitempty"`
+}
+
+// SamplingRequest 采样请求
+type SamplingRequest struct {
+ Messages []SamplingMessage `json:"messages"`
+ Model string `json:"model,omitempty"`
+ MaxTokens int `json:"maxTokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+}
+
+// SamplingMessage 采样消息
+type SamplingMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// SamplingResponse 采样响应
+type SamplingResponse struct {
+ Content []SamplingContent `json:"content"`
+ Model string `json:"model,omitempty"`
+ StopReason string `json:"stopReason,omitempty"`
+}
+
+// SamplingContent 采样内容
+type SamplingContent struct {
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+}
+
diff --git a/internal/security/executor.go b/internal/security/executor.go
new file mode 100644
index 00000000..9131ea44
--- /dev/null
+++ b/internal/security/executor.go
@@ -0,0 +1,730 @@
+package security
+
+import (
+ "context"
+ "fmt"
+ "os/exec"
+ "strings"
+ "time"
+
+ "cyberstrike-ai/internal/config"
+ "cyberstrike-ai/internal/mcp"
+ "go.uber.org/zap"
+)
+
+// Executor 安全工具执行器
+type Executor struct {
+ config *config.SecurityConfig
+ mcpServer *mcp.Server
+ logger *zap.Logger
+}
+
+// NewExecutor 创建新的执行器
+func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
+ return &Executor{
+ config: cfg,
+ mcpServer: mcpServer,
+ logger: logger,
+ }
+}
+
+// ExecuteTool 执行安全工具
+func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[string]interface{}) (*mcp.ToolResult, error) {
+ e.logger.Info("ExecuteTool被调用",
+ zap.String("toolName", toolName),
+ zap.Any("args", args),
+ )
+
+ // 特殊处理:exec工具直接执行系统命令
+ if toolName == "exec" {
+ e.logger.Info("执行exec工具")
+ return e.executeSystemCommand(ctx, args)
+ }
+
+ // 查找工具配置
+ var toolConfig *config.ToolConfig
+ for i := range e.config.Tools {
+ if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
+ toolConfig = &e.config.Tools[i]
+ break
+ }
+ }
+
+ if toolConfig == nil {
+ e.logger.Error("工具未找到或未启用",
+ zap.String("toolName", toolName),
+ zap.Int("totalTools", len(e.config.Tools)),
+ )
+ return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
+ }
+
+ e.logger.Info("找到工具配置",
+ zap.String("toolName", toolName),
+ zap.String("command", toolConfig.Command),
+ zap.Strings("args", toolConfig.Args),
+ )
+
+ // 构建命令 - 根据工具类型使用不同的参数格式
+ cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
+
+ e.logger.Info("构建命令参数完成",
+ zap.String("toolName", toolName),
+ zap.Strings("cmdArgs", cmdArgs),
+ zap.Int("argsCount", len(cmdArgs)),
+ )
+
+ // 验证命令参数
+ if len(cmdArgs) == 0 {
+ e.logger.Warn("命令参数为空",
+ zap.String("toolName", toolName),
+ zap.Any("inputArgs", args),
+ )
+ return &mcp.ToolResult{
+ Content: []mcp.Content{
+ {
+ Type: "text",
+ Text: fmt.Sprintf("错误: 工具 %s 缺少必需的参数。接收到的参数: %v", toolName, args),
+ },
+ },
+ IsError: true,
+ }, nil
+ }
+
+ // 执行命令
+ cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
+
+ e.logger.Info("执行安全工具",
+ zap.String("tool", toolName),
+ zap.Strings("args", cmdArgs),
+ )
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ e.logger.Error("工具执行失败",
+ zap.String("tool", toolName),
+ zap.Error(err),
+ zap.String("output", string(output)),
+ )
+ return &mcp.ToolResult{
+ Content: []mcp.Content{
+ {
+ Type: "text",
+ Text: fmt.Sprintf("工具执行失败: %v\n输出: %s", err, string(output)),
+ },
+ },
+ IsError: true,
+ }, nil
+ }
+
+ e.logger.Info("工具执行成功",
+ zap.String("tool", toolName),
+ zap.String("output", string(output)),
+ )
+
+ return &mcp.ToolResult{
+ Content: []mcp.Content{
+ {
+ Type: "text",
+ Text: string(output),
+ },
+ },
+ IsError: false,
+ }, nil
+}
+
+// RegisterTools 注册工具到MCP服务器
+func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
+ e.logger.Info("开始注册工具",
+ zap.Int("totalTools", len(e.config.Tools)),
+ )
+
+ for i, toolConfig := range e.config.Tools {
+ if !toolConfig.Enabled {
+ e.logger.Debug("跳过未启用的工具",
+ zap.String("tool", toolConfig.Name),
+ )
+ continue
+ }
+
+ // 创建工具配置的副本,避免闭包问题
+ toolName := toolConfig.Name
+ toolConfigCopy := toolConfig
+
+ tool := mcp.Tool{
+ Name: toolConfigCopy.Name,
+ Description: toolConfigCopy.Description,
+ InputSchema: e.buildInputSchema(&toolConfigCopy),
+ }
+
+ handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
+ e.logger.Info("工具handler被调用",
+ zap.String("toolName", toolName),
+ zap.Any("args", args),
+ )
+ return e.ExecuteTool(ctx, toolName, args)
+ }
+
+ mcpServer.RegisterTool(tool, handler)
+ e.logger.Info("注册安全工具成功",
+ zap.String("tool", toolConfigCopy.Name),
+ zap.String("command", toolConfigCopy.Command),
+ zap.Int("index", i),
+ )
+ }
+
+ e.logger.Info("工具注册完成",
+ zap.Int("registeredCount", len(e.config.Tools)),
+ )
+}
+
+// buildCommandArgs 构建命令参数
+func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConfig, args map[string]interface{}) []string {
+ cmdArgs := make([]string, 0)
+
+ // 如果配置中定义了参数映射,使用配置中的映射规则
+ if len(toolConfig.Parameters) > 0 {
+ // 先添加固定参数
+ cmdArgs = append(cmdArgs, toolConfig.Args...)
+
+ // 按位置参数排序
+ positionalParams := make([]config.ParameterConfig, 0)
+ flagParams := make([]config.ParameterConfig, 0)
+
+ for _, param := range toolConfig.Parameters {
+ if param.Position != nil {
+ positionalParams = append(positionalParams, param)
+ } else {
+ flagParams = append(flagParams, param)
+ }
+ }
+
+ // 对位置参数按位置排序
+ for i := 0; i < len(positionalParams); i++ {
+ for _, param := range positionalParams {
+ if param.Position != nil && *param.Position == i {
+ value := e.getParamValue(args, param)
+ if value == nil {
+ if param.Required {
+ // 必需参数缺失,返回空数组让上层处理错误
+ e.logger.Warn("缺少必需的位置参数",
+ zap.String("tool", toolName),
+ zap.String("param", param.Name),
+ zap.Int("position", *param.Position),
+ )
+ return []string{}
+ }
+ break
+ }
+ cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
+ break
+ }
+ }
+ }
+
+ // 处理标志参数
+ for _, param := range flagParams {
+ value := e.getParamValue(args, param)
+ if value == nil {
+ if param.Required {
+ // 必需参数缺失,返回空数组让上层处理错误
+ e.logger.Warn("缺少必需的标志参数",
+ zap.String("tool", toolName),
+ zap.String("param", param.Name),
+ )
+ return []string{}
+ }
+ continue
+ }
+
+ // 布尔值特殊处理:如果为 false,跳过;如果为 true,只添加标志
+ if param.Type == "bool" {
+ if boolVal, ok := value.(bool); ok {
+ if !boolVal {
+ continue // false 时不添加任何参数
+ }
+ // true 时只添加标志,不添加值
+ if param.Flag != "" {
+ cmdArgs = append(cmdArgs, param.Flag)
+ }
+ continue
+ }
+ }
+
+ format := param.Format
+ if format == "" {
+ format = "flag" // 默认格式
+ }
+
+ switch format {
+ case "flag":
+ // --flag value 或 -f value
+ if param.Flag != "" {
+ cmdArgs = append(cmdArgs, param.Flag)
+ }
+ formattedValue := e.formatParamValue(param, value)
+ if formattedValue != "" {
+ cmdArgs = append(cmdArgs, formattedValue)
+ }
+ case "combined":
+ // --flag=value 或 -f=value
+ if param.Flag != "" {
+ cmdArgs = append(cmdArgs, fmt.Sprintf("%s=%s", param.Flag, e.formatParamValue(param, value)))
+ } else {
+ cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
+ }
+ case "template":
+ // 使用模板字符串
+ if param.Template != "" {
+ template := param.Template
+ template = strings.ReplaceAll(template, "{flag}", param.Flag)
+ template = strings.ReplaceAll(template, "{value}", e.formatParamValue(param, value))
+ template = strings.ReplaceAll(template, "{name}", param.Name)
+ cmdArgs = append(cmdArgs, strings.Fields(template)...)
+ } else {
+ // 如果没有模板,使用默认格式
+ if param.Flag != "" {
+ cmdArgs = append(cmdArgs, param.Flag)
+ }
+ cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
+ }
+ case "positional":
+ // 位置参数(已在上面处理)
+ cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
+ default:
+ // 默认:直接添加值
+ cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
+ }
+ }
+
+ return cmdArgs
+ }
+
+ // 向后兼容:如果没有定义参数,使用旧的硬编码逻辑
+ switch toolName {
+ case "nmap":
+ // nmap -sT -sV -sC target [ports]
+ // 使用 -sT (TCP连接扫描) 而不是 -sS (SYN扫描),因为 -sS 需要root权限
+ e.logger.Debug("处理nmap参数",
+ zap.Any("args", args),
+ )
+
+ // 尝试多种方式获取target参数
+ var target string
+ var ok bool
+
+ // 方式1: 直接获取target
+ if target, ok = args["target"].(string); !ok || target == "" {
+ // 方式2: 尝试从tool字段获取(兼容某些格式)
+ if toolVal, exists := args["tool"]; exists {
+ if toolMap, ok := toolVal.(map[string]interface{}); ok {
+ if t, ok := toolMap["target"].(string); ok {
+ target = t
+ }
+ }
+ }
+ }
+
+ if target == "" {
+ e.logger.Warn("nmap缺少target参数",
+ zap.Any("args", args),
+ )
+ return cmdArgs // 返回空数组,让上层处理错误
+ }
+
+ e.logger.Debug("提取到target",
+ zap.String("target", target),
+ )
+
+ // 处理URL格式的目标(提取域名)
+ if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
+ // 提取域名部分
+ target = strings.TrimPrefix(target, "http://")
+ target = strings.TrimPrefix(target, "https://")
+ // 移除路径部分
+ if idx := strings.Index(target, "/"); idx != -1 {
+ target = target[:idx]
+ }
+ }
+
+ // 添加扫描选项:-sT (TCP连接扫描,不需要root权限), -sV (版本检测), -sC (默认脚本)
+ cmdArgs = append(cmdArgs, "-sT", "-sV", "-sC")
+
+ // 添加端口范围(如果指定)
+ if ports, ok := args["ports"].(string); ok && ports != "" {
+ cmdArgs = append(cmdArgs, "-p", ports)
+ }
+
+ // 添加目标
+ cmdArgs = append(cmdArgs, target)
+
+ e.logger.Debug("nmap命令参数构建完成",
+ zap.Strings("cmdArgs", cmdArgs),
+ )
+ case "sqlmap":
+ // sqlmap -u url
+ if url, ok := args["url"].(string); ok {
+ cmdArgs = append(cmdArgs, "-u", url, "--batch", "--level=3", "--risk=2")
+ }
+ case "nikto":
+ // nikto -h target
+ if target, ok := args["target"].(string); ok {
+ cmdArgs = append(cmdArgs, "-h", target)
+ }
+ case "dirb":
+ // dirb url
+ if url, ok := args["url"].(string); ok {
+ cmdArgs = append(cmdArgs, url)
+ }
+ default:
+ // 通用处理
+ cmdArgs = append(cmdArgs, toolConfig.Args...)
+ for key, value := range args {
+ if key == "_tool_name" {
+ continue
+ }
+ cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", key))
+ if strValue, ok := value.(string); ok {
+ cmdArgs = append(cmdArgs, strValue)
+ } else {
+ cmdArgs = append(cmdArgs, fmt.Sprintf("%v", value))
+ }
+ }
+ }
+
+ return cmdArgs
+}
+
+// getParamValue 获取参数值,支持默认值
+func (e *Executor) getParamValue(args map[string]interface{}, param config.ParameterConfig) interface{} {
+ // 从参数中获取值
+ if value, ok := args[param.Name]; ok && value != nil {
+ return value
+ }
+
+ // 如果参数是必需的但没有提供,返回 nil(让上层处理错误)
+ if param.Required {
+ return nil
+ }
+
+ // 返回默认值
+ return param.Default
+}
+
+// formatParamValue 格式化参数值
+func (e *Executor) formatParamValue(param config.ParameterConfig, value interface{}) string {
+ switch param.Type {
+ case "bool":
+ // 布尔值应该在上层处理,这里不应该被调用
+ if boolVal, ok := value.(bool); ok {
+ return fmt.Sprintf("%v", boolVal)
+ }
+ return "false"
+ case "array":
+ // 数组:转换为逗号分隔的字符串
+ if arr, ok := value.([]interface{}); ok {
+ strs := make([]string, 0, len(arr))
+ for _, item := range arr {
+ strs = append(strs, fmt.Sprintf("%v", item))
+ }
+ return strings.Join(strs, ",")
+ }
+ return fmt.Sprintf("%v", value)
+ default:
+ return fmt.Sprintf("%v", value)
+ }
+}
+
+// executeSystemCommand 执行系统命令
+func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
+ // 获取命令
+ command, ok := args["command"].(string)
+ if !ok {
+ return &mcp.ToolResult{
+ Content: []mcp.Content{
+ {
+ Type: "text",
+ Text: "错误: 缺少command参数",
+ },
+ },
+ IsError: true,
+ }, nil
+ }
+
+ if command == "" {
+ return &mcp.ToolResult{
+ Content: []mcp.Content{
+ {
+ Type: "text",
+ Text: "错误: command参数不能为空",
+ },
+ },
+ IsError: true,
+ }, nil
+ }
+
+ // 安全检查:记录执行的命令
+ e.logger.Warn("执行系统命令",
+ zap.String("command", command),
+ )
+
+ // 获取shell类型(可选,默认为sh)
+ shell := "sh"
+ if s, ok := args["shell"].(string); ok && s != "" {
+ shell = s
+ }
+
+ // 获取工作目录(可选)
+ workDir := ""
+ if wd, ok := args["workdir"].(string); ok && wd != "" {
+ workDir = wd
+ }
+
+ // 构建命令
+ var cmd *exec.Cmd
+ if workDir != "" {
+ cmd = exec.CommandContext(ctx, shell, "-c", command)
+ cmd.Dir = workDir
+ } else {
+ cmd = exec.CommandContext(ctx, shell, "-c", command)
+ }
+
+ // 执行命令
+ e.logger.Info("执行系统命令",
+ zap.String("command", command),
+ zap.String("shell", shell),
+ zap.String("workdir", workDir),
+ )
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ e.logger.Error("系统命令执行失败",
+ zap.String("command", command),
+ zap.Error(err),
+ zap.String("output", string(output)),
+ )
+ return &mcp.ToolResult{
+ Content: []mcp.Content{
+ {
+ Type: "text",
+ Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
+ },
+ },
+ IsError: true,
+ }, nil
+ }
+
+ e.logger.Info("系统命令执行成功",
+ zap.String("command", command),
+ zap.String("output_length", fmt.Sprintf("%d", len(output))),
+ )
+
+ return &mcp.ToolResult{
+ Content: []mcp.Content{
+ {
+ Type: "text",
+ Text: string(output),
+ },
+ },
+ IsError: false,
+ }, nil
+}
+
+// buildInputSchema 构建输入模式
+func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
+ schema := map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{},
+ "required": []string{},
+ }
+
+ // 如果配置中定义了参数,优先使用配置中的参数定义
+ if len(toolConfig.Parameters) > 0 {
+ properties := make(map[string]interface{})
+ required := []string{}
+
+ for _, param := range toolConfig.Parameters {
+ prop := map[string]interface{}{
+ "type": param.Type,
+ "description": param.Description,
+ }
+
+ // 添加默认值
+ if param.Default != nil {
+ prop["default"] = param.Default
+ }
+
+ // 添加枚举选项
+ if len(param.Options) > 0 {
+ prop["enum"] = param.Options
+ }
+
+ properties[param.Name] = prop
+
+ // 添加到必需参数列表
+ if param.Required {
+ required = append(required, param.Name)
+ }
+ }
+
+ schema["properties"] = properties
+ schema["required"] = required
+ return schema
+ }
+
+ // 向后兼容:如果没有定义参数,使用旧的硬编码逻辑
+ switch toolConfig.Name {
+ case "nmap":
+ schema["properties"] = map[string]interface{}{
+ "target": map[string]interface{}{
+ "type": "string",
+ "description": "目标IP地址或域名",
+ },
+ "ports": map[string]interface{}{
+ "type": "string",
+ "description": "端口范围,例如: 1-1000",
+ },
+ }
+ schema["required"] = []string{"target"}
+ case "sqlmap":
+ schema["properties"] = map[string]interface{}{
+ "url": map[string]interface{}{
+ "type": "string",
+ "description": "目标URL",
+ },
+ }
+ schema["required"] = []string{"url"}
+ case "nikto", "dirb":
+ schema["properties"] = map[string]interface{}{
+ "target": map[string]interface{}{
+ "type": "string",
+ "description": "目标URL",
+ },
+ }
+ schema["required"] = []string{"target"}
+ case "exec":
+ schema["properties"] = map[string]interface{}{
+ "command": map[string]interface{}{
+ "type": "string",
+ "description": "要执行的系统命令",
+ },
+ "shell": map[string]interface{}{
+ "type": "string",
+ "description": "使用的shell(可选,默认为sh)",
+ },
+ "workdir": map[string]interface{}{
+ "type": "string",
+ "description": "工作目录(可选)",
+ },
+ }
+ schema["required"] = []string{"command"}
+ }
+
+ return schema
+}
+
+// Vulnerability 漏洞信息
+type Vulnerability struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Severity string `json:"severity"` // low, medium, high, critical
+ Title string `json:"title"`
+ Description string `json:"description"`
+ Target string `json:"target"`
+ FoundAt time.Time `json:"foundAt"`
+ Details string `json:"details"`
+}
+
+// AnalyzeResults 分析工具执行结果,提取漏洞信息
+func (e *Executor) AnalyzeResults(toolName string, result *mcp.ToolResult) []Vulnerability {
+ vulnerabilities := []Vulnerability{}
+
+ if result.IsError {
+ return vulnerabilities
+ }
+
+ // 分析输出内容
+ for _, content := range result.Content {
+ if content.Type == "text" {
+ vulns := e.parseToolOutput(toolName, content.Text)
+ vulnerabilities = append(vulnerabilities, vulns...)
+ }
+ }
+
+ return vulnerabilities
+}
+
+// parseToolOutput 解析工具输出
+func (e *Executor) parseToolOutput(toolName, output string) []Vulnerability {
+ vulnerabilities := []Vulnerability{}
+
+ // 简单的漏洞检测逻辑
+ outputLower := strings.ToLower(output)
+
+ // SQL注入检测
+ if strings.Contains(outputLower, "sql injection") || strings.Contains(outputLower, "sqli") {
+ vulnerabilities = append(vulnerabilities, Vulnerability{
+ ID: fmt.Sprintf("sql-%d", time.Now().Unix()),
+ Type: "SQL Injection",
+ Severity: "high",
+ Title: "SQL注入漏洞",
+ Description: "检测到潜在的SQL注入漏洞",
+ FoundAt: time.Now(),
+ Details: output,
+ })
+ }
+
+ // XSS检测
+ if strings.Contains(outputLower, "xss") || strings.Contains(outputLower, "cross-site scripting") {
+ vulnerabilities = append(vulnerabilities, Vulnerability{
+ ID: fmt.Sprintf("xss-%d", time.Now().Unix()),
+ Type: "XSS",
+ Severity: "medium",
+ Title: "跨站脚本攻击漏洞",
+ Description: "检测到潜在的XSS漏洞",
+ FoundAt: time.Now(),
+ Details: output,
+ })
+ }
+
+ // 开放端口检测
+ if toolName == "nmap" {
+ lines := strings.Split(output, "\n")
+ for _, line := range lines {
+ if strings.Contains(line, "open") && strings.Contains(line, "port") {
+ vulnerabilities = append(vulnerabilities, Vulnerability{
+ ID: fmt.Sprintf("port-%d", time.Now().Unix()),
+ Type: "Open Port",
+ Severity: "low",
+ Title: "开放端口",
+ Description: fmt.Sprintf("发现开放端口: %s", line),
+ FoundAt: time.Now(),
+ Details: line,
+ })
+ }
+ }
+ }
+
+ return vulnerabilities
+}
+
+// GetVulnerabilityReport 生成漏洞报告
+func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[string]interface{} {
+ severityCount := map[string]int{
+ "critical": 0,
+ "high": 0,
+ "medium": 0,
+ "low": 0,
+ }
+
+ for _, vuln := range vulnerabilities {
+ severityCount[vuln.Severity]++
+ }
+
+ return map[string]interface{}{
+ "total": len(vulnerabilities),
+ "severityCount": severityCount,
+ "vulnerabilities": vulnerabilities,
+ "generatedAt": time.Now(),
+ }
+}
+
diff --git a/run.sh b/run.sh
new file mode 100644
index 00000000..a752063c
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# CyberStrikeAI 启动脚本
+
+echo "🚀 启动 CyberStrikeAI..."
+
+# 检查配置文件
+if [ ! -f "config.yaml" ]; then
+ echo "❌ 配置文件 config.yaml 不存在"
+ exit 1
+fi
+
+# 检查Go环境
+if ! command -v go &> /dev/null; then
+ echo "❌ Go 未安装,请先安装 Go 1.21 或更高版本"
+ exit 1
+fi
+
+# 下载依赖
+echo "📦 下载依赖..."
+go mod download
+
+# 构建项目
+echo "🔨 构建项目..."
+go build -o cyberstrike-ai cmd/server/main.go
+
+if [ $? -ne 0 ]; then
+ echo "❌ 构建失败"
+ exit 1
+fi
+
+# 运行服务器
+echo "✅ 启动服务器..."
+./cyberstrike-ai
+
diff --git a/web/static/css/style.css b/web/static/css/style.css
new file mode 100644
index 00000000..ad78f71d
--- /dev/null
+++ b/web/static/css/style.css
@@ -0,0 +1,691 @@
+* {
+ margin: 0;
+ padding: 0;
+ box-sizing: border-box;
+}
+
+:root {
+ --primary-color: #1a1a1a;
+ --secondary-color: #2d2d2d;
+ --accent-color: #0066ff;
+ --accent-hover: #0052cc;
+ --bg-primary: #ffffff;
+ --bg-secondary: #f8f9fa;
+ --bg-tertiary: #f1f3f5;
+ --text-primary: #1a1a1a;
+ --text-secondary: #6c757d;
+ --text-muted: #adb5bd;
+ --border-color: #e9ecef;
+ --success-color: #28a745;
+ --warning-color: #ffc107;
+ --error-color: #dc3545;
+ --shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.05);
+ --shadow-md: 0 4px 6px rgba(0, 0, 0, 0.1);
+ --shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.15);
+}
+
+body {
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif;
+ background: var(--bg-secondary);
+ margin: 0;
+ padding: 0;
+ color: var(--text-primary);
+ line-height: 1.6;
+ height: 100vh;
+ overflow: hidden;
+}
+
+.container {
+ max-width: 100%;
+ margin: 0;
+ background: var(--bg-primary);
+ height: 100vh;
+ display: flex;
+ flex-direction: column;
+ box-shadow: var(--shadow-lg);
+ overflow: hidden;
+}
+
+.main-layout {
+ display: flex;
+ flex: 1;
+ overflow: hidden;
+ min-height: 0;
+}
+
+header {
+ background: var(--primary-color);
+ color: white;
+ padding: 24px 32px;
+ border-bottom: 1px solid rgba(255, 255, 255, 0.1);
+ flex-shrink: 0;
+}
+
+.header-content {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+}
+
+.logo {
+ display: flex;
+ align-items: center;
+ gap: 12px;
+}
+
+.logo svg {
+ color: var(--accent-color);
+}
+
+.logo h1 {
+ font-size: 1.75rem;
+ font-weight: 600;
+ letter-spacing: -0.5px;
+ margin: 0;
+}
+
+.header-subtitle {
+ font-size: 0.875rem;
+ color: rgba(255, 255, 255, 0.7);
+ margin: 0;
+ font-weight: 400;
+}
+
+/* 侧边栏样式 */
+.sidebar {
+ width: 280px;
+ background: var(--bg-secondary);
+ border-right: 1px solid var(--border-color);
+ display: flex;
+ flex-direction: column;
+ flex-shrink: 0;
+ height: 100%;
+ overflow: hidden;
+}
+
+.sidebar-header {
+ padding: 16px;
+ border-bottom: 1px solid var(--border-color);
+ flex-shrink: 0;
+}
+
+.new-chat-btn {
+ width: 100%;
+ padding: 10px 16px;
+ background: var(--accent-color);
+ color: white;
+ border: none;
+ border-radius: 8px;
+ font-size: 0.9375rem;
+ font-weight: 500;
+ cursor: pointer;
+ transition: all 0.2s;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ gap: 8px;
+}
+
+.new-chat-btn:hover {
+ background: var(--accent-hover);
+ transform: translateY(-1px);
+ box-shadow: var(--shadow-sm);
+}
+
+.new-chat-btn span {
+ font-size: 1.2em;
+ line-height: 1;
+}
+
+.sidebar-content {
+ flex: 1;
+ overflow-y: auto;
+ overflow-x: hidden;
+ padding: 16px;
+ min-height: 0;
+}
+
+.sidebar-title {
+ font-size: 0.8125rem;
+ font-weight: 600;
+ color: var(--text-secondary);
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+ margin-bottom: 12px;
+ padding: 0 8px;
+}
+
+.conversations-list {
+ display: flex;
+ flex-direction: column;
+ gap: 4px;
+}
+
+.conversation-item {
+ padding: 12px;
+ border-radius: 8px;
+ cursor: pointer;
+ transition: all 0.2s;
+ border: 1px solid transparent;
+}
+
+.conversation-item:hover {
+ background: var(--bg-tertiary);
+}
+
+.conversation-item.active {
+ background: var(--bg-primary);
+ border-color: var(--accent-color);
+}
+
+.conversation-title {
+ font-size: 0.875rem;
+ font-weight: 500;
+ color: var(--text-primary);
+ margin-bottom: 4px;
+ overflow: hidden;
+ text-overflow: ellipsis;
+ white-space: nowrap;
+}
+
+.conversation-time {
+ font-size: 0.75rem;
+ color: var(--text-muted);
+}
+
+/* 对话界面样式 */
+.chat-container {
+ display: flex;
+ flex-direction: column;
+ flex: 1;
+ min-width: 0;
+ background: var(--bg-primary);
+ overflow: hidden;
+ height: 100%;
+}
+
+.chat-messages {
+ flex: 1;
+ overflow-y: auto;
+ overflow-x: hidden;
+ padding: 24px;
+ background: var(--bg-secondary);
+ display: flex;
+ flex-direction: column;
+ min-height: 0;
+}
+
+.message {
+ margin-bottom: 24px;
+ display: flex;
+ align-items: flex-start;
+ gap: 12px;
+ animation: fadeIn 0.3s ease-in;
+ width: 100%;
+}
+
+@keyframes fadeIn {
+ from {
+ opacity: 0;
+ transform: translateY(10px);
+ }
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+}
+
+.message.user {
+ flex-direction: row-reverse;
+ justify-content: flex-end;
+}
+
+.message.system {
+ justify-content: center;
+ margin-bottom: 16px;
+}
+
+.message-avatar {
+ width: 32px;
+ height: 32px;
+ border-radius: 6px;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ font-size: 0.75rem;
+ font-weight: 600;
+ flex-shrink: 0;
+ margin-top: 2px;
+}
+
+.message.user .message-avatar {
+ background: var(--accent-color);
+ color: white;
+}
+
+.message.assistant .message-avatar {
+ background: var(--bg-tertiary);
+ color: var(--text-secondary);
+ border: 1px solid var(--border-color);
+}
+
+.message.system .message-avatar {
+ display: none;
+}
+
+.message-content {
+ flex: 0 1 auto;
+ max-width: 70%;
+ min-width: 120px;
+ display: flex;
+ flex-direction: column;
+}
+
+.message.user .message-content {
+ align-items: flex-end;
+ margin-left: auto;
+}
+
+.message.assistant .message-content {
+ align-items: flex-start;
+ margin-right: auto;
+}
+
+.message.system .message-content {
+ max-width: 90%;
+ align-items: center;
+ margin: 0 auto;
+}
+
+.message-bubble {
+ padding: 12px 16px;
+ border-radius: 8px;
+ word-wrap: break-word;
+ word-break: break-word;
+ line-height: 1.6;
+ box-shadow: var(--shadow-sm);
+ white-space: pre-wrap;
+}
+
+.message.user .message-bubble {
+ background: var(--accent-color);
+ color: white;
+ border-bottom-right-radius: 2px;
+}
+
+.message.assistant .message-bubble {
+ background: var(--bg-primary);
+ color: var(--text-primary);
+ border: 1px solid var(--border-color);
+ border-bottom-left-radius: 2px;
+}
+
+.message.assistant .message-bubble pre {
+ margin: 0;
+ white-space: pre-wrap;
+ word-wrap: break-word;
+ font-family: inherit;
+}
+
+.message.system .message-bubble {
+ background: var(--bg-tertiary);
+ color: var(--text-secondary);
+ border: 1px solid var(--border-color);
+ text-align: center;
+ font-size: 0.875rem;
+ padding: 10px 16px;
+ width: 100%;
+}
+
+.message-time {
+ font-size: 0.6875rem;
+ color: var(--text-muted);
+ margin-top: 4px;
+ padding: 0 2px;
+ font-weight: 400;
+}
+
+/* MCP调用区域 */
+.mcp-call-section {
+ margin-top: 12px;
+ padding-top: 12px;
+ border-top: 1px solid var(--border-color);
+ width: 100%;
+}
+
+.mcp-call-label {
+ font-size: 0.75rem;
+ color: var(--text-secondary);
+ margin-bottom: 8px;
+ display: flex;
+ align-items: center;
+ gap: 6px;
+ font-weight: 500;
+}
+
+.mcp-call-label::before {
+ content: '';
+ width: 4px;
+ height: 4px;
+ background: var(--accent-color);
+ border-radius: 50%;
+ display: inline-block;
+ flex-shrink: 0;
+}
+
+.mcp-call-buttons {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 6px;
+}
+
+.chat-input-container {
+ display: flex;
+ gap: 12px;
+ padding: 20px 24px;
+ background: var(--bg-primary);
+ border-top: 1px solid var(--border-color);
+ flex-shrink: 0;
+}
+
+.chat-input-container input {
+ flex: 1;
+ padding: 12px 16px;
+ border: 1px solid var(--border-color);
+ border-radius: 8px;
+ font-size: 0.9375rem;
+ outline: none;
+ transition: all 0.2s;
+ background: var(--bg-primary);
+ color: var(--text-primary);
+}
+
+.chat-input-container input:focus {
+ border-color: var(--accent-color);
+ box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.1);
+}
+
+.chat-input-container input::placeholder {
+ color: var(--text-muted);
+}
+
+.chat-input-container button {
+ padding: 12px 24px;
+ background: var(--accent-color);
+ color: white;
+ border: none;
+ border-radius: 8px;
+ cursor: pointer;
+ font-size: 0.9375rem;
+ font-weight: 500;
+ transition: all 0.2s;
+ white-space: nowrap;
+}
+
+.chat-input-container button:hover {
+ background: var(--accent-hover);
+ transform: translateY(-1px);
+ box-shadow: var(--shadow-md);
+}
+
+.chat-input-container button:active {
+ transform: translateY(0);
+}
+
+/* MCP调用详情按钮 */
+.mcp-detail-btn {
+ display: inline-flex;
+ align-items: center;
+ gap: 6px;
+ padding: 6px 12px;
+ background: var(--bg-primary);
+ color: var(--accent-color);
+ border: 1px solid var(--border-color);
+ border-radius: 6px;
+ font-size: 0.8125rem;
+ font-weight: 500;
+ cursor: pointer;
+ transition: all 0.2s;
+}
+
+.mcp-detail-btn:hover {
+ background: var(--accent-color);
+ color: white;
+ border-color: var(--accent-color);
+ transform: translateY(-1px);
+ box-shadow: var(--shadow-sm);
+}
+
+.mcp-detail-btn:active {
+ transform: translateY(0);
+}
+
+/* 模态框样式 */
+.modal {
+ display: none;
+ position: fixed;
+ z-index: 1000;
+ left: 0;
+ top: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgba(0, 0, 0, 0.6);
+ backdrop-filter: blur(4px);
+ overflow: auto;
+ animation: fadeIn 0.2s ease-in;
+}
+
+.modal-content {
+ background-color: var(--bg-primary);
+ margin: 5% auto;
+ padding: 0;
+ border-radius: 12px;
+ width: 90%;
+ max-width: 900px;
+ max-height: 85vh;
+ display: flex;
+ flex-direction: column;
+ box-shadow: var(--shadow-lg);
+ border: 1px solid var(--border-color);
+ animation: slideDown 0.3s ease-out;
+}
+
+@keyframes slideDown {
+ from {
+ opacity: 0;
+ transform: translateY(-20px);
+ }
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+}
+
+.modal-header {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ padding: 20px 24px;
+ border-bottom: 1px solid var(--border-color);
+ background: var(--bg-primary);
+}
+
+.modal-header h2 {
+ margin: 0;
+ font-size: 1.25rem;
+ font-weight: 600;
+ color: var(--text-primary);
+}
+
+.modal-close {
+ width: 32px;
+ height: 32px;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ border-radius: 6px;
+ cursor: pointer;
+ color: var(--text-secondary);
+ font-size: 1.5rem;
+ line-height: 1;
+ transition: all 0.2s;
+ border: none;
+ background: transparent;
+}
+
+.modal-close:hover {
+ background: var(--bg-tertiary);
+ color: var(--text-primary);
+}
+
+.modal-body {
+ padding: 24px;
+ overflow-y: auto;
+ flex: 1;
+}
+
+.detail-section {
+ margin-bottom: 24px;
+}
+
+.detail-section:last-child {
+ margin-bottom: 0;
+}
+
+.detail-section h3 {
+ color: var(--text-primary);
+ margin-bottom: 12px;
+ padding-bottom: 8px;
+ border-bottom: 1px solid var(--border-color);
+ font-size: 0.9375rem;
+ font-weight: 600;
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+}
+
+.detail-item {
+ margin-bottom: 10px;
+ padding: 8px 0;
+ display: flex;
+ align-items: baseline;
+ gap: 8px;
+}
+
+.detail-item strong {
+ color: var(--text-secondary);
+ font-weight: 500;
+ font-size: 0.875rem;
+ min-width: 80px;
+}
+
+.detail-item span {
+ color: var(--text-primary);
+ font-size: 0.875rem;
+}
+
+.code-block {
+ background: var(--bg-tertiary);
+ border: 1px solid var(--border-color);
+ border-radius: 8px;
+ padding: 16px;
+ font-family: 'SF Mono', 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', monospace;
+ font-size: 0.8125rem;
+ line-height: 1.6;
+ overflow-x: auto;
+ white-space: pre-wrap;
+ word-wrap: break-word;
+ max-height: 400px;
+ overflow-y: auto;
+ color: var(--text-primary);
+}
+
+.code-block.error {
+ background: #fff5f5;
+ border-color: var(--error-color);
+ color: var(--error-color);
+}
+
+/* 滚动条样式 */
+::-webkit-scrollbar {
+ width: 8px;
+ height: 8px;
+}
+
+::-webkit-scrollbar-track {
+ background: var(--bg-secondary);
+ border-radius: 4px;
+}
+
+::-webkit-scrollbar-thumb {
+ background: var(--text-muted);
+ border-radius: 4px;
+}
+
+::-webkit-scrollbar-thumb:hover {
+ background: var(--text-secondary);
+}
+
+/* 响应式设计 */
+@media (max-width: 768px) {
+ body {
+ height: 100vh;
+ overflow: hidden;
+ }
+
+ .container {
+ border-radius: 0;
+ height: 100vh;
+ overflow: hidden;
+ }
+
+ header {
+ padding: 16px 20px;
+ flex-shrink: 0;
+ }
+
+ .logo h1 {
+ font-size: 1.5rem;
+ }
+
+ .header-subtitle {
+ display: none;
+ }
+
+ .main-layout {
+ overflow: hidden;
+ min-height: 0;
+ }
+
+ .sidebar {
+ height: 100%;
+ overflow: hidden;
+ }
+
+ .sidebar-content {
+ min-height: 0;
+ }
+
+ .chat-container {
+ height: 100%;
+ overflow: hidden;
+ }
+
+ .chat-messages {
+ padding: 16px;
+ min-height: 0;
+ }
+
+ .message-content {
+ max-width: 85%;
+ }
+
+ .chat-input-container {
+ padding: 16px;
+ flex-shrink: 0;
+ }
+
+ .modal-content {
+ width: 95%;
+ margin: 10% auto;
+ }
+}
diff --git a/web/static/js/app.js b/web/static/js/app.js
new file mode 100644
index 00000000..fbb16ba2
--- /dev/null
+++ b/web/static/js/app.js
@@ -0,0 +1,400 @@
+
+// 当前对话ID
+let currentConversationId = null;
+
+// 发送消息
+async function sendMessage() {
+ const input = document.getElementById('chat-input');
+ const message = input.value.trim();
+
+ if (!message) {
+ return;
+ }
+
+ // 显示用户消息
+ addMessage('user', message);
+ input.value = '';
+
+ // 显示加载状态
+ const loadingId = addMessage('system', '正在处理中...');
+
+ try {
+ const response = await fetch('/api/agent-loop', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ message: message,
+ conversationId: currentConversationId
+ }),
+ });
+
+ const data = await response.json();
+
+ // 移除加载消息
+ removeMessage(loadingId);
+
+ if (response.ok) {
+ // 更新当前对话ID
+ if (data.conversationId) {
+ currentConversationId = data.conversationId;
+ updateActiveConversation();
+ }
+
+ // 如果有MCP执行ID,显示所有调用
+ const mcpIds = data.mcpExecutionIds || [];
+ addMessage('assistant', data.response, mcpIds);
+
+ // 刷新对话列表
+ loadConversations();
+ } else {
+ addMessage('system', '错误: ' + (data.error || '未知错误'));
+ }
+ } catch (error) {
+ removeMessage(loadingId);
+ addMessage('system', '错误: ' + error.message);
+ }
+}
+
+// 消息计数器,确保ID唯一
+let messageCounter = 0;
+
+// 添加消息
+function addMessage(role, content, mcpExecutionIds = null) {
+ const messagesDiv = document.getElementById('chat-messages');
+ const messageDiv = document.createElement('div');
+ messageCounter++;
+ const id = 'msg-' + Date.now() + '-' + messageCounter + '-' + Math.random().toString(36).substr(2, 9);
+ messageDiv.id = id;
+ messageDiv.className = 'message ' + role;
+
+ // 创建头像
+ const avatar = document.createElement('div');
+ avatar.className = 'message-avatar';
+ if (role === 'user') {
+ avatar.textContent = 'U';
+ } else if (role === 'assistant') {
+ avatar.textContent = 'A';
+ } else {
+ avatar.textContent = 'S';
+ }
+ messageDiv.appendChild(avatar);
+
+ // 创建消息内容容器
+ const contentWrapper = document.createElement('div');
+ contentWrapper.className = 'message-content';
+
+ // 创建消息气泡
+ const bubble = document.createElement('div');
+ bubble.className = 'message-bubble';
+ // 处理换行和格式化
+ const formattedContent = content.replace(/\n/g, '
');
+ bubble.innerHTML = formattedContent;
+ contentWrapper.appendChild(bubble);
+
+ // 添加时间戳
+ const timeDiv = document.createElement('div');
+ timeDiv.className = 'message-time';
+ timeDiv.textContent = new Date().toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' });
+ contentWrapper.appendChild(timeDiv);
+
+ // 如果有MCP执行ID,添加查看详情区域
+ if (mcpExecutionIds && Array.isArray(mcpExecutionIds) && mcpExecutionIds.length > 0 && role === 'assistant') {
+ const mcpSection = document.createElement('div');
+ mcpSection.className = 'mcp-call-section';
+
+ const mcpLabel = document.createElement('div');
+ mcpLabel.className = 'mcp-call-label';
+ mcpLabel.textContent = `工具调用 (${mcpExecutionIds.length})`;
+ mcpSection.appendChild(mcpLabel);
+
+ const buttonsContainer = document.createElement('div');
+ buttonsContainer.className = 'mcp-call-buttons';
+
+ mcpExecutionIds.forEach((execId, index) => {
+ const detailBtn = document.createElement('button');
+ detailBtn.className = 'mcp-detail-btn';
+ detailBtn.innerHTML = `调用 #${index + 1}`;
+ detailBtn.onclick = () => showMCPDetail(execId);
+ buttonsContainer.appendChild(detailBtn);
+ });
+
+ mcpSection.appendChild(buttonsContainer);
+ contentWrapper.appendChild(mcpSection);
+ }
+
+ messageDiv.appendChild(contentWrapper);
+ messagesDiv.appendChild(messageDiv);
+ messagesDiv.scrollTop = messagesDiv.scrollHeight;
+ return id;
+}
+
+// 移除消息
+function removeMessage(id) {
+ const messageDiv = document.getElementById(id);
+ if (messageDiv) {
+ messageDiv.remove();
+ }
+}
+
+// 回车发送消息
+document.getElementById('chat-input').addEventListener('keypress', function(e) {
+ if (e.key === 'Enter') {
+ sendMessage();
+ }
+});
+
+// 显示MCP调用详情
+async function showMCPDetail(executionId) {
+ try {
+ const response = await fetch(`/api/monitor/execution/${executionId}`);
+ const exec = await response.json();
+
+ if (response.ok) {
+ // 填充模态框内容
+ document.getElementById('detail-tool-name').textContent = exec.toolName || 'Unknown';
+ document.getElementById('detail-execution-id').textContent = exec.id || 'N/A';
+ document.getElementById('detail-status').textContent = getStatusText(exec.status);
+ document.getElementById('detail-time').textContent = new Date(exec.startTime).toLocaleString('zh-CN');
+
+ // 请求参数
+ const requestData = {
+ tool: exec.toolName,
+ arguments: exec.arguments
+ };
+ document.getElementById('detail-request').textContent = JSON.stringify(requestData, null, 2);
+
+ // 响应结果
+ if (exec.result) {
+ const responseData = {
+ content: exec.result.content,
+ isError: exec.result.isError
+ };
+ document.getElementById('detail-response').textContent = JSON.stringify(responseData, null, 2);
+ document.getElementById('detail-response').className = exec.result.isError ? 'code-block error' : 'code-block';
+ } else {
+ document.getElementById('detail-response').textContent = '暂无响应数据';
+ }
+
+ // 错误信息
+ if (exec.error) {
+ document.getElementById('detail-error-section').style.display = 'block';
+ document.getElementById('detail-error').textContent = exec.error;
+ } else {
+ document.getElementById('detail-error-section').style.display = 'none';
+ }
+
+ // 显示模态框
+ document.getElementById('mcp-detail-modal').style.display = 'block';
+ } else {
+ alert('获取详情失败: ' + (exec.error || '未知错误'));
+ }
+ } catch (error) {
+ alert('获取详情失败: ' + error.message);
+ }
+}
+
+// 关闭MCP详情模态框
+function closeMCPDetail() {
+ document.getElementById('mcp-detail-modal').style.display = 'none';
+}
+
+// 点击模态框外部关闭
+window.onclick = function(event) {
+ const modal = document.getElementById('mcp-detail-modal');
+ if (event.target == modal) {
+ closeMCPDetail();
+ }
+}
+
+// 工具函数
+function getStatusText(status) {
+ const statusMap = {
+ 'pending': '等待中',
+ 'running': '执行中',
+ 'completed': '已完成',
+ 'failed': '失败'
+ };
+ return statusMap[status] || status;
+}
+
+function formatDuration(ms) {
+ const seconds = Math.floor(ms / 1000);
+ const minutes = Math.floor(seconds / 60);
+ const hours = Math.floor(minutes / 60);
+
+ if (hours > 0) {
+ return `${hours}小时${minutes % 60}分钟`;
+ } else if (minutes > 0) {
+ return `${minutes}分钟${seconds % 60}秒`;
+ } else {
+ return `${seconds}秒`;
+ }
+}
+
+function escapeHtml(text) {
+ const div = document.createElement('div');
+ div.textContent = text;
+ return div.innerHTML;
+}
+
+// 开始新对话
+function startNewConversation() {
+ currentConversationId = null;
+ document.getElementById('chat-messages').innerHTML = '';
+ addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。');
+ updateActiveConversation();
+}
+
+// 加载对话列表
+async function loadConversations() {
+ try {
+ const response = await fetch('/api/conversations?limit=50');
+ const conversations = await response.json();
+
+ const listContainer = document.getElementById('conversations-list');
+ listContainer.innerHTML = '';
+
+ if (conversations.length === 0) {
+ listContainer.innerHTML = '
安全测试平台
+