Add files via upload

This commit is contained in:
公明
2025-11-08 18:56:23 +08:00
committed by GitHub
commit add33e1cf7
24 changed files with 5228 additions and 0 deletions

319
README.md Normal file
View File

@@ -0,0 +1,319 @@
# CyberStrikeAI
基于Golang和Gin框架的AI驱动自主渗透测试平台使用MCP协议集成安全工具。
## 功能特性
- 🤖 **AI代理连接** - 支持Claude、GPT等兼容MCP的AI代理通过FastMCP协议连接
- 🧠 **智能分析** - 决策引擎分析目标并选择最佳测试策略
-**自主执行** - AI代理执行全面的安全评估
- 🔄 **实时适应** - 系统根据结果和发现的漏洞进行调整
- 📊 **高级报告** - 可视化方式输出漏洞卡片和风险分析
- 💬 **对话式交互** - 前端以对话形式调用后端agent-loop
- 📈 **实时监控** - 监控安全工具的执行状态、结果、调用次数等
## 项目结构
```
CyberStrikeAI/
├── cmd/
│ └── server/
│ └── main.go # 程序入口
├── internal/
│ ├── agent/ # AI代理模块
│ ├── app/ # 应用初始化
│ ├── config/ # 配置管理
│ ├── handler/ # HTTP处理器
│ ├── logger/ # 日志系统
│ ├── mcp/ # MCP协议实现
│ └── security/ # 安全工具执行器
├── web/
│ ├── static/ # 静态资源
│ │ ├── css/
│ │ └── js/
│ └── templates/ # HTML模板
├── config.yaml # 配置文件
├── go.mod # Go模块文件
└── README.md # 说明文档
```
## 快速开始
### 前置要求
- Go 1.21 或更高版本
- OpenAI API Key或其他兼容OpenAI协议的API
- 安全工具可选nmap, sqlmap, nikto, dirb
### 安装步骤
1. **克隆项目**
```bash
cd /Users/temp/Desktop/wenjian/tools/CyberStrikeAI
```
2. **安装依赖**
```bash
go mod download
```
3. **配置**
编辑 `config.yaml` 文件设置您的OpenAI API Key
```yaml
openai:
api_key: "sk-your-api-key-here"
base_url: "https://api.openai.com/v1"
model: "gpt-4"
```
4. **启动服务器**
#### 方式一:使用启动脚本
```bash
./run.sh
```
#### 方式二:直接运行
```bash
go run cmd/server/main.go
```
#### 方式三:编译后运行
```bash
go build -o cyberstrike-ai cmd/server/main.go
./cyberstrike-ai
```
5. **访问应用**
打开浏览器访问http://localhost:8080
## 配置说明
### 服务器配置
```yaml
server:
host: "0.0.0.0"
port: 8080
```
### MCP配置
```yaml
mcp:
enabled: true
host: "0.0.0.0"
port: 8081
```
### 安全工具配置
```yaml
security:
tools:
- name: "nmap"
command: "nmap"
args: ["-sV", "-sC"]
description: "网络扫描工具"
enabled: true
```
## 使用示例
### 对话式渗透测试
在"对话测试"标签页中,您可以:
1. **网络扫描**
```
扫描 192.168.1.1 的开放端口
```
2. **SQL注入检测**
```
检测 https://example.com 的SQL注入漏洞
```
3. **Web漏洞扫描**
```
扫描 https://example.com 的Web服务器漏洞
```
4. **目录扫描**
```
扫描 https://example.com 的隐藏目录
```
### 监控工具执行
在"工具监控"标签页中,您可以:
- 查看所有工具的执行统计
- 查看详细的执行记录
- 查看发现的漏洞列表
- 实时监控工具状态
## API接口
### Agent Loop API
**POST** `/api/agent-loop`
请求体:
```json
{
"message": "扫描 192.168.1.1"
}
```
使用示例:
```bash
curl -X POST http://localhost:8080/api/agent-loop \
-H "Content-Type: application/json" \
-d '{"message": "扫描 192.168.1.1"}'
```
### 监控API
- **GET** `/api/monitor` - 获取所有监控信息
- **GET** `/api/monitor/execution/:id` - 获取特定执行记录
- **GET** `/api/monitor/stats` - 获取统计信息
- **GET** `/api/monitor/vulnerabilities` - 获取漏洞列表
使用示例:
```bash
# 获取所有监控信息
curl http://localhost:8080/api/monitor
# 获取统计信息
curl http://localhost:8080/api/monitor/stats
# 获取漏洞列表
curl http://localhost:8080/api/monitor/vulnerabilities
```
### MCP接口
**POST** `/api/mcp` - MCP协议端点
## MCP协议
本项目实现了MCPModel Context Protocol协议支持
- `initialize` - 初始化连接
- `tools/list` - 列出可用工具
- `tools/call` - 调用工具
工具调用是异步执行的,系统会跟踪每个工具的执行状态和结果。
### MCP协议使用示例
#### 初始化连接
```bash
curl -X POST http://localhost:8080/api/mcp \
-H "Content-Type: application/json" \
-d '{
"jsonrpc": "2.0",
"id": "1",
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
}'
```
#### 列出工具
```bash
curl -X POST http://localhost:8080/api/mcp \
-H "Content-Type: application/json" \
-d '{
"jsonrpc": "2.0",
"id": "2",
"method": "tools/list"
}'
```
#### 调用工具
```bash
curl -X POST http://localhost:8080/api/mcp \
-H "Content-Type: application/json" \
-d '{
"jsonrpc": "2.0",
"id": "3",
"method": "tools/call",
"params": {
"name": "nmap",
"arguments": {
"target": "192.168.1.1",
"ports": "1-1000"
}
}
}'
```
## 安全工具支持
当前支持的安全工具:
- **nmap** - 网络扫描
- **sqlmap** - SQL注入检测
- **nikto** - Web服务器扫描
- **dirb** - 目录扫描
可以通过修改 `config.yaml` 添加更多工具。
## 故障排除
### 问题无法连接到OpenAI API
- 检查API Key是否正确
- 检查网络连接
- 检查base_url配置
### 问题:工具执行失败
- 确保已安装相应的安全工具nmap, sqlmap等
- 检查工具是否在PATH中
- 某些工具可能需要root权限
### 问题:前端无法加载
- 检查服务器是否正常运行
- 检查端口8080是否被占用
- 查看浏览器控制台错误信息
## 安全注意事项
⚠️ **重要提示**
- 仅对您拥有或已获得授权的系统进行测试
- 遵守相关法律法规
- 建议在隔离的测试环境中使用
- 不要在生产环境中使用
- 某些安全工具可能需要root权限
## 开发
### 添加新工具
1. 在 `config.yaml` 中添加工具配置
2. 在 `internal/security/executor.go` 的 `buildCommandArgs` 方法中添加参数构建逻辑
3. 在 `internal/agent/agent.go` 的 `getAvailableTools` 方法中添加工具定义
### 构建
```bash
go build -o cyberstrike-ai cmd/server/main.go
```
## 许可证
本项目仅供学习和研究使用。
## 贡献
欢迎提交Issue和Pull Request

36
cmd/server/main.go Normal file
View File

@@ -0,0 +1,36 @@
package main
import (
"cyberstrike-ai/internal/app"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/logger"
"flag"
"fmt"
)
func main() {
var configPath = flag.String("config", "config.yaml", "配置文件路径")
flag.Parse()
// 加载配置
cfg, err := config.Load(*configPath)
if err != nil {
fmt.Printf("加载配置失败: %v\n", err)
return
}
// 初始化日志
log := logger.New(cfg.Log.Level, cfg.Log.Output)
// 创建应用
application, err := app.New(cfg, log)
if err != nil {
log.Fatal("应用初始化失败", "error", err)
}
// 启动服务器
if err := application.Run(); err != nil {
log.Fatal("服务器启动失败", "error", err)
}
}

174
config.yaml Normal file
View File

@@ -0,0 +1,174 @@
server:
host: "0.0.0.0"
port: 8080
log:
level: "info"
output: "stdout"
mcp:
enabled: true
host: "0.0.0.0"
port: 8081
openai:
api_key: "sk-xxx" # 请设置您的OpenAI API Key
base_url: "https://api.deepseek.com/v1"
model: "deepseek-chat"
database:
path: "data/conversations.db"
security:
tools:
# 示例1: 使用参数定义的工具(推荐方式)
- name: "nmap"
command: "nmap"
args: ["-sT", "-sV", "-sC"] # 固定参数
description: "网络扫描工具,用于发现网络主机和服务"
enabled: true
parameters:
- name: "target"
type: "string"
description: "目标IP地址或域名"
required: true
position: 0 # 位置参数,放在最后
format: "positional"
- name: "ports"
type: "string"
description: "端口范围,例如: 1-1000, 80,443,8080"
required: false
flag: "-p"
format: "flag"
# 示例2: 标志参数工具
- name: "sqlmap"
command: "sqlmap"
description: "SQL注入检测和利用工具"
enabled: true
parameters:
- name: "url"
type: "string"
description: "目标URL例如: http://example.com/page?id=1"
required: true
flag: "-u"
format: "flag"
- name: "batch"
type: "bool"
description: "非交互模式"
required: false
default: true
flag: "--batch"
format: "flag"
- name: "level"
type: "int"
description: "测试级别 (1-5)"
required: false
default: 3
flag: "--level"
format: "combined" # --level=3
# 示例3: 位置参数工具
- name: "nikto"
command: "nikto"
description: "Web服务器扫描工具"
enabled: true
parameters:
- name: "target"
type: "string"
description: "目标URL或IP地址"
required: true
flag: "-h"
format: "flag"
# 示例4: 简单位置参数
- name: "dirb"
command: "dirb"
description: "Web目录扫描工具"
enabled: true
parameters:
- name: "url"
type: "string"
description: "目标URL"
required: true
position: 0
format: "positional"
- name: "wordlist"
type: "string"
description: "字典文件路径"
required: false
flag: "-w"
format: "flag"
# 示例5: 执行系统命令
- name: "exec"
command: "sh"
args: ["-c"]
description: "执行系统命令(谨慎使用)"
enabled: true
parameters:
- name: "command"
type: "string"
description: "要执行的系统命令"
required: true
position: 0
format: "positional"
- name: "shell"
type: "string"
description: "使用的shell可选默认为sh"
required: false
default: "sh"
- name: "workdir"
type: "string"
description: "工作目录"
required: false
# 示例6: 自定义工具 - 使用模板格式
- name: "custom_scanner"
command: "my-scanner"
description: "自定义扫描工具示例"
enabled: false # 默认禁用,需要时启用
parameters:
- name: "target"
type: "string"
description: "扫描目标"
required: true
flag: "--target"
format: "flag"
- name: "mode"
type: "string"
description: "扫描模式"
required: false
default: "normal"
options: ["normal", "aggressive", "stealth"] # 枚举值
flag: "--mode"
format: "combined" # --mode=normal
- name: "threads"
type: "int"
description: "线程数"
required: false
default: 10
flag: "-t"
format: "flag"
- name: "output"
type: "string"
description: "输出文件路径"
required: false
flag: "-o"
format: "template"
template: "-o {value}" # 自定义模板
- name: "verbose"
type: "bool"
description: "详细输出"
required: false
default: false
flag: "-v"
format: "flag" # 布尔值如果为true只添加-v不添加值
# 示例7: 向后兼容 - 不定义parameters使用旧的硬编码逻辑
# - name: "legacy_tool"
# command: "legacy"
# args: ["--option"]
# description: "旧工具(使用硬编码逻辑)"
# enabled: false

BIN
data/conversations.db Normal file

Binary file not shown.

BIN
data/conversations.db-shm Normal file

Binary file not shown.

BIN
data/conversations.db-wal Normal file

Binary file not shown.

38
go.mod Normal file
View File

@@ -0,0 +1,38 @@
module cyberstrike-ai
go 1.21
require (
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.5.0
github.com/mattn/go-sqlite3 v1.14.18
go.uber.org/zap v1.26.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
)

96
go.sum Normal file
View File

@@ -0,0 +1,96 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

576
internal/agent/agent.go Normal file
View File

@@ -0,0 +1,576 @@
package agent
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// Agent AI代理
type Agent struct {
openAIClient *http.Client
config *config.OpenAIConfig
mcpServer *mcp.Server
logger *zap.Logger
}
// NewAgent 创建新的Agent
func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logger) *Agent {
return &Agent{
openAIClient: &http.Client{Timeout: 5 * time.Minute},
config: cfg,
mcpServer: mcpServer,
logger: logger,
}
}
// ChatMessage 聊天消息
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// MarshalJSON 自定义JSON序列化将tool_calls中的arguments转换为JSON字符串
func (cm ChatMessage) MarshalJSON() ([]byte, error) {
// 构建序列化结构
aux := map[string]interface{}{
"role": cm.Role,
}
// 添加content如果存在
if cm.Content != "" {
aux["content"] = cm.Content
}
// 添加tool_call_id如果存在
if cm.ToolCallID != "" {
aux["tool_call_id"] = cm.ToolCallID
}
// 转换tool_calls将arguments转换为JSON字符串
if len(cm.ToolCalls) > 0 {
toolCallsJSON := make([]map[string]interface{}, len(cm.ToolCalls))
for i, tc := range cm.ToolCalls {
// 将arguments转换为JSON字符串
argsJSON := ""
if tc.Function.Arguments != nil {
argsBytes, err := json.Marshal(tc.Function.Arguments)
if err != nil {
return nil, err
}
argsJSON = string(argsBytes)
}
toolCallsJSON[i] = map[string]interface{}{
"id": tc.ID,
"type": tc.Type,
"function": map[string]interface{}{
"name": tc.Function.Name,
"arguments": argsJSON,
},
}
}
aux["tool_calls"] = toolCallsJSON
}
return json.Marshal(aux)
}
// OpenAIRequest OpenAI API请求
type OpenAIRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Tools []Tool `json:"tools,omitempty"`
}
// OpenAIResponse OpenAI API响应
type OpenAIResponse struct {
ID string `json:"id"`
Choices []Choice `json:"choices"`
Error *Error `json:"error,omitempty"`
}
// Choice 选择
type Choice struct {
Message MessageWithTools `json:"message"`
FinishReason string `json:"finish_reason"`
}
// MessageWithTools 带工具调用的消息
type MessageWithTools struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// Tool OpenAI工具定义
type Tool struct {
Type string `json:"type"`
Function FunctionDefinition `json:"function"`
}
// FunctionDefinition 函数定义
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
// Error OpenAI错误
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
}
// ToolCall 工具调用
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
}
// FunctionCall 函数调用
type FunctionCall struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// UnmarshalJSON 自定义JSON解析处理arguments可能是字符串或对象的情况
func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
type Alias FunctionCall
aux := &struct {
Name string `json:"name"`
Arguments interface{} `json:"arguments"`
*Alias
}{
Alias: (*Alias)(fc),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
fc.Name = aux.Name
// 处理arguments可能是字符串或对象的情况
switch v := aux.Arguments.(type) {
case map[string]interface{}:
fc.Arguments = v
case string:
// 如果是字符串尝试解析为JSON
if err := json.Unmarshal([]byte(v), &fc.Arguments); err != nil {
// 如果解析失败创建一个包含原始字符串的map
fc.Arguments = map[string]interface{}{
"raw": v,
}
}
case nil:
fc.Arguments = make(map[string]interface{})
default:
// 其他类型尝试转换为map
fc.Arguments = map[string]interface{}{
"value": v,
}
}
return nil
}
// AgentLoopResult Agent Loop执行结果
type AgentLoopResult struct {
Response string
MCPExecutionIDs []string
}
// AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
messages := []ChatMessage{
{
Role: "system",
Content: "你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。当需要执行工具时,使用提供的工具函数。",
},
}
// 添加历史消息数据库只保存user和assistant消息
a.logger.Info("处理历史消息",
zap.Int("count", len(historyMessages)),
)
addedCount := 0
for i, msg := range historyMessages {
// 只添加有内容的消息
if msg.Content != "" {
messages = append(messages, ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
addedCount++
contentPreview := msg.Content
if len(contentPreview) > 50 {
contentPreview = contentPreview[:50] + "..."
}
a.logger.Info("添加历史消息到上下文",
zap.Int("index", i),
zap.String("role", msg.Role),
zap.String("content", contentPreview),
)
}
}
a.logger.Info("构建消息数组",
zap.Int("historyMessages", len(historyMessages)),
zap.Int("addedMessages", addedCount),
zap.Int("totalMessages", len(messages)),
)
// 添加当前用户消息
messages = append(messages, ChatMessage{
Role: "user",
Content: userInput,
})
result := &AgentLoopResult{
MCPExecutionIDs: make([]string, 0),
}
maxIterations := 10
for i := 0; i < maxIterations; i++ {
// 获取可用工具
tools := a.getAvailableTools()
// 记录每次调用OpenAI
if i == 0 {
a.logger.Info("调用OpenAI",
zap.Int("iteration", i+1),
zap.Int("messagesCount", len(messages)),
)
// 记录前几条消息的内容(用于调试)
for j, msg := range messages {
if j >= 5 { // 只记录前5条
break
}
contentPreview := msg.Content
if len(contentPreview) > 100 {
contentPreview = contentPreview[:100] + "..."
}
a.logger.Debug("消息内容",
zap.Int("index", j),
zap.String("role", msg.Role),
zap.String("content", contentPreview),
)
}
} else {
a.logger.Info("调用OpenAI",
zap.Int("iteration", i+1),
zap.Int("messagesCount", len(messages)),
)
}
// 调用OpenAI
response, err := a.callOpenAI(ctx, messages, tools)
if err != nil {
result.Response = ""
return result, fmt.Errorf("调用OpenAI失败: %w", err)
}
if response.Error != nil {
result.Response = ""
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
}
if len(response.Choices) == 0 {
result.Response = ""
return result, fmt.Errorf("没有收到响应")
}
choice := response.Choices[0]
// 检查是否有工具调用
if len(choice.Message.ToolCalls) > 0 {
// 添加assistant消息包含工具调用
messages = append(messages, ChatMessage{
Role: "assistant",
Content: choice.Message.Content,
ToolCalls: choice.Message.ToolCalls,
})
// 执行所有工具调用
for _, toolCall := range choice.Message.ToolCalls {
// 执行工具
execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
messages = append(messages, ChatMessage{
Role: "tool",
ToolCallID: toolCall.ID,
Content: fmt.Sprintf("工具执行失败: %v", err),
})
} else {
messages = append(messages, ChatMessage{
Role: "tool",
ToolCallID: toolCall.ID,
Content: execResult.Result,
})
// 收集执行ID
if execResult.ExecutionID != "" {
result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID)
}
}
}
continue
}
// 添加assistant响应
messages = append(messages, ChatMessage{
Role: "assistant",
Content: choice.Message.Content,
})
// 如果完成,返回结果
if choice.FinishReason == "stop" {
result.Response = choice.Message.Content
return result, nil
}
}
result.Response = "达到最大迭代次数"
return result, nil
}
// getAvailableTools 获取可用工具
func (a *Agent) getAvailableTools() []Tool {
// 从MCP服务器获取工具列表
executions := a.mcpServer.GetAllExecutions()
toolNames := make(map[string]bool)
for _, exec := range executions {
toolNames[exec.ToolName] = true
}
tools := []Tool{
{
Type: "function",
Function: FunctionDefinition{
Name: "nmap",
Description: "使用nmap进行网络扫描发现开放端口和服务。支持IP地址、域名或URL会自动提取域名。使用TCP连接扫描不需要root权限。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"target": map[string]interface{}{
"type": "string",
"description": "目标IP地址、域名或URL如 https://example.com。如果是URL会自动提取域名部分。",
},
"ports": map[string]interface{}{
"type": "string",
"description": "要扫描的端口范围,例如: 1-1000 或 80,443,8080。如果不指定将扫描常用端口。",
},
},
"required": []string{"target"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "sqlmap",
Description: "使用sqlmap检测SQL注入漏洞",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"url": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
},
"required": []string{"url"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "nikto",
Description: "使用nikto扫描Web服务器漏洞",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"target": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
},
"required": []string{"target"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "dirb",
Description: "使用dirb进行目录扫描",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"url": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
},
"required": []string{"url"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "exec",
Description: "执行系统命令(谨慎使用,仅用于必要的系统操作)",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"command": map[string]interface{}{
"type": "string",
"description": "要执行的系统命令",
},
"shell": map[string]interface{}{
"type": "string",
"description": "使用的shell可选默认为sh",
},
"workdir": map[string]interface{}{
"type": "string",
"description": "工作目录(可选)",
},
},
"required": []string{"command"},
},
},
},
}
return tools
}
// callOpenAI 调用OpenAI API
func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
reqBody := OpenAIRequest{
Model: a.config.Model,
Messages: messages,
}
if len(tools) > 0 {
reqBody.Tools = tools
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+a.config.APIKey)
resp, err := a.openAIClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 记录响应内容(用于调试)
if resp.StatusCode != http.StatusOK {
a.logger.Warn("OpenAI API返回非200状态码",
zap.Int("status", resp.StatusCode),
zap.String("body", string(body)),
)
}
var response OpenAIResponse
if err := json.Unmarshal(body, &response); err != nil {
a.logger.Error("解析OpenAI响应失败",
zap.Error(err),
zap.String("body", string(body)),
)
return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body))
}
return &response, nil
}
// parseToolCall 解析工具调用
func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) {
// 简单解析,实际应该更复杂
// 格式: [TOOL_CALL]tool_name:arg1=value1,arg2=value2
if !strings.HasPrefix(content, "[TOOL_CALL]") {
return nil, fmt.Errorf("不是有效的工具调用格式")
}
parts := strings.Split(content[len("[TOOL_CALL]"):], ":")
if len(parts) < 2 {
return nil, fmt.Errorf("工具调用格式错误")
}
toolName := strings.TrimSpace(parts[0])
argsStr := strings.TrimSpace(parts[1])
args := make(map[string]interface{})
argPairs := strings.Split(argsStr, ",")
for _, pair := range argPairs {
kv := strings.Split(pair, "=")
if len(kv) == 2 {
args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
}
}
args["_tool_name"] = toolName
return args, nil
}
// ToolExecutionResult 工具执行结果
type ToolExecutionResult struct {
Result string
ExecutionID string
}
// executeToolViaMCP 通过MCP执行工具
func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) {
a.logger.Info("通过MCP执行工具",
zap.String("tool", toolName),
zap.Any("args", args),
)
// 通过MCP服务器调用工具
result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args)
if err != nil {
return nil, fmt.Errorf("工具执行失败: %w", err)
}
// 格式化结果
var resultText strings.Builder
for _, content := range result.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
return &ToolExecutionResult{
Result: resultText.String(),
ExecutionID: executionID,
}, nil
}

163
internal/app/app.go Normal file
View File

@@ -0,0 +1,163 @@
package app
import (
"fmt"
"net/http"
"os"
"path/filepath"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/handler"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// App 应用
type App struct {
config *config.Config
logger *logger.Logger
router *gin.Engine
mcpServer *mcp.Server
agent *agent.Agent
executor *security.Executor
db *database.DB
}
// New 创建新应用
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
gin.SetMode(gin.ReleaseMode)
router := gin.Default()
// CORS中间件
router.Use(corsMiddleware())
// 初始化数据库
dbPath := cfg.Database.Path
if dbPath == "" {
dbPath = "data/conversations.db"
}
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
db, err := database.NewDB(dbPath, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err)
}
// 创建MCP服务器
mcpServer := mcp.NewServer(log.Logger)
// 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
// 注册工具
executor.RegisterTools(mcpServer)
// 创建Agent
agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger)
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
conversationHandler := handler.NewConversationHandler(db, log.Logger)
// 设置路由
setupRoutes(router, agentHandler, monitorHandler, conversationHandler, mcpServer)
return &App{
config: cfg,
logger: log,
router: router,
mcpServer: mcpServer,
agent: agent,
executor: executor,
db: db,
}, nil
}
// Run 启动应用
func (a *App) Run() error {
// 启动MCP服务器如果启用
if a.config.MCP.Enabled {
go func() {
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
mux := http.NewServeMux()
mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP)
if err := http.ListenAndServe(mcpAddr, mux); err != nil {
a.logger.Error("MCP服务器启动失败", zap.Error(err))
}
}()
}
// 启动主服务器
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
return a.router.Run(addr)
}
// setupRoutes 设置路由
func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, mcpServer *mcp.Server) {
// API路由
api := router.Group("/api")
{
// Agent Loop
api.POST("/agent-loop", agentHandler.AgentLoop)
// 对话历史
api.POST("/conversations", conversationHandler.CreateConversation)
api.GET("/conversations", conversationHandler.ListConversations)
api.GET("/conversations/:id", conversationHandler.GetConversation)
api.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
// 监控
api.GET("/monitor", monitorHandler.Monitor)
api.GET("/monitor/execution/:id", monitorHandler.GetExecution)
api.GET("/monitor/stats", monitorHandler.GetStats)
api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
// MCP端点
api.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
})
}
// 静态文件
router.Static("/static", "./web/static")
router.LoadHTMLGlob("web/templates/*")
// 前端页面
router.GET("/", func(c *gin.Context) {
c.HTML(http.StatusOK, "index.html", nil)
})
}
// corsMiddleware CORS中间件
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

114
internal/config/config.go Normal file
View File

@@ -0,0 +1,114 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
type Config struct {
Server ServerConfig `yaml:"server"`
Log LogConfig `yaml:"log"`
MCP MCPConfig `yaml:"mcp"`
OpenAI OpenAIConfig `yaml:"openai"`
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
}
type ServerConfig struct {
Host string `yaml:"host"`
Port int `yaml:"port"`
}
type LogConfig struct {
Level string `yaml:"level"`
Output string `yaml:"output"`
}
type MCPConfig struct {
Enabled bool `yaml:"enabled"`
Host string `yaml:"host"`
Port int `yaml:"port"`
}
type OpenAIConfig struct {
APIKey string `yaml:"api_key"`
BaseURL string `yaml:"base_url"`
Model string `yaml:"model"`
}
type SecurityConfig struct {
Tools []ToolConfig `yaml:"tools"`
}
type DatabaseConfig struct {
Path string `yaml:"path"`
}
type ToolConfig struct {
Name string `yaml:"name"`
Command string `yaml:"command"`
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
Description string `yaml:"description"`
Enabled bool `yaml:"enabled"`
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
}
// ParameterConfig 参数配置
type ParameterConfig struct {
Name string `yaml:"name"` // 参数名称
Type string `yaml:"type"` // 参数类型: string, int, bool, array
Description string `yaml:"description"` // 参数描述
Required bool `yaml:"required,omitempty"` // 是否必需
Default interface{} `yaml:"default,omitempty"` // 默认值
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
Position *int `yaml:"position,omitempty"` // 位置参数的位置从0开始
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
}
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
return &cfg, nil
}
func Default() *Config {
return &Config{
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
},
Log: LogConfig{
Level: "info",
Output: "stdout",
},
MCP: MCPConfig{
Enabled: true,
Host: "0.0.0.0",
Port: 8081,
},
OpenAI: OpenAIConfig{
BaseURL: "https://api.openai.com/v1",
Model: "gpt-4",
},
Security: SecurityConfig{
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 加载,不在此硬编码
},
Database: DatabaseConfig{
Path: "data/conversations.db",
},
}
}

View File

@@ -0,0 +1,256 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Conversation 对话
type Conversation struct {
ID string `json:"id"`
Title string `json:"title"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
Messages []Message `json:"messages,omitempty"`
}
// Message 消息
type Message struct {
ID string `json:"id"`
ConversationID string `json:"conversationId"`
Role string `json:"role"`
Content string `json:"content"`
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// CreateConversation 创建新对话
func (db *DB) CreateConversation(title string) (*Conversation, error) {
id := uuid.New().String()
now := time.Now()
_, err := db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
id, title, now, now,
)
if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err)
}
return &Conversation{
ID: id,
Title: title,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// GetConversation 获取对话
func (db *DB) GetConversation(id string) (*Conversation, error) {
var conv Conversation
var createdAt, updatedAt string
err := db.QueryRow(
"SELECT id, title, created_at, updated_at FROM conversations WHERE id = ?",
id,
).Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("对话不存在")
}
return nil, fmt.Errorf("查询对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
}
// 加载消息
messages, err := db.GetMessages(id)
if err != nil {
return nil, fmt.Errorf("加载消息失败: %w", err)
}
conv.Messages = messages
return &conv, nil
}
// ListConversations 列出所有对话
func (db *DB) ListConversations(limit, offset int) ([]*Conversation, error) {
rows, err := db.Query(
"SELECT id, title, created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
limit, offset,
)
if err != nil {
return nil, fmt.Errorf("查询对话列表失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
if err := rows.Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
}
conversations = append(conversations, &conv)
}
return conversations, nil
}
// UpdateConversationTitle 更新对话标题
func (db *DB) UpdateConversationTitle(id, title string) error {
_, err := db.Exec(
"UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
title, time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新对话标题失败: %w", err)
}
return nil
}
// UpdateConversationTime 更新对话时间
func (db *DB) UpdateConversationTime(id string) error {
_, err := db.Exec(
"UPDATE conversations SET updated_at = ? WHERE id = ?",
time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新对话时间失败: %w", err)
}
return nil
}
// DeleteConversation 删除对话
func (db *DB) DeleteConversation(id string) error {
_, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除对话失败: %w", err)
}
return nil
}
// AddMessage 添加消息
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
id := uuid.New().String()
var mcpIDsJSON string
if len(mcpExecutionIDs) > 0 {
jsonData, err := json.Marshal(mcpExecutionIDs)
if err != nil {
db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
} else {
mcpIDsJSON = string(jsonData)
}
}
_, err := db.Exec(
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
id, conversationID, role, content, mcpIDsJSON, time.Now(),
)
if err != nil {
return nil, fmt.Errorf("添加消息失败: %w", err)
}
// 更新对话时间
if err := db.UpdateConversationTime(conversationID); err != nil {
db.logger.Warn("更新对话时间失败", zap.Error(err))
}
message := &Message{
ID: id,
ConversationID: conversationID,
Role: role,
Content: content,
MCPExecutionIDs: mcpExecutionIDs,
CreatedAt: time.Now(),
}
return message, nil
}
// GetMessages 获取对话的所有消息
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
rows, err := db.Query(
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
conversationID,
)
if err != nil {
return nil, fmt.Errorf("查询消息失败: %w", err)
}
defer rows.Close()
var messages []Message
for rows.Next() {
var msg Message
var mcpIDsJSON sql.NullString
var createdAt string
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
return nil, fmt.Errorf("扫描消息失败: %w", err)
}
// 尝试多种时间格式解析
var err error
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err != nil {
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err != nil {
msg.CreatedAt, err = time.Parse(time.RFC3339, createdAt)
}
// 解析MCP执行ID
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
}
}
messages = append(messages, msg)
}
return messages, nil
}

View File

@@ -0,0 +1,90 @@
package database
import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
"go.uber.org/zap"
)
// DB 数据库连接
type DB struct {
*sql.DB
logger *zap.Logger
}
// NewDB 创建数据库连接
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
if err != nil {
return nil, fmt.Errorf("打开数据库失败: %w", err)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
database := &DB{
DB: db,
logger: logger,
}
// 初始化表
if err := database.initTables(); err != nil {
return nil, fmt.Errorf("初始化表失败: %w", err)
}
return database, nil
}
// initTables 初始化数据库表
func (db *DB) initTables() error {
// 创建对话表
createConversationsTable := `
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
);`
// 创建消息表
createMessagesTable := `
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
mcp_execution_ids TEXT,
created_at DATETIME NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);`
// 创建索引
createIndexes := `
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at);
`
if _, err := db.Exec(createConversationsTable); err != nil {
return fmt.Errorf("创建conversations表失败: %w", err)
}
if _, err := db.Exec(createMessagesTable); err != nil {
return fmt.Errorf("创建messages表失败: %w", err)
}
if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
db.logger.Info("数据库表初始化完成")
return nil
}
// Close 关闭数据库连接
func (db *DB) Close() error {
return db.DB.Close()
}

134
internal/handler/agent.go Normal file
View File

@@ -0,0 +1,134 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AgentHandler Agent处理器
type AgentHandler struct {
agent *agent.Agent
db *database.DB
logger *zap.Logger
}
// NewAgentHandler 创建新的Agent处理器
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
return &AgentHandler{
agent: agent,
db: db,
logger: logger,
}
}
// ChatRequest 聊天请求
type ChatRequest struct {
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
}
// ChatResponse 聊天响应
type ChatResponse struct {
Response string `json:"response"`
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表
ConversationID string `json:"conversationId"` // 对话ID
Time time.Time `json:"time"`
}
// AgentLoop 处理Agent Loop请求
func (h *AgentHandler) AgentLoop(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.logger.Info("收到Agent Loop请求",
zap.String("message", req.Message),
zap.String("conversationId", req.ConversationID),
)
// 如果没有对话ID创建新对话
conversationID := req.ConversationID
if conversationID == "" {
title := req.Message
if len(title) > 50 {
title = title[:50] + "..."
}
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
conversationID = conv.ID
}
// 获取历史消息(排除当前消息,因为还没保存)
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
historyMessages = []database.Message{}
}
h.logger.Info("获取历史消息",
zap.String("conversationId", conversationID),
zap.Int("count", len(historyMessages)),
)
// 将数据库消息转换为Agent消息格式
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
for i, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
contentPreview := msg.Content
if len(contentPreview) > 50 {
contentPreview = contentPreview[:50] + "..."
}
h.logger.Info("添加历史消息",
zap.Int("index", i),
zap.String("role", msg.Role),
zap.String("content", contentPreview),
)
}
h.logger.Info("历史消息转换完成",
zap.Int("originalCount", len(historyMessages)),
zap.Int("convertedCount", len(agentHistoryMessages)),
)
// 保存用户消息
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 执行Agent Loop传入历史消息
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 保存助手回复
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.Error(err))
}
c.JSON(http.StatusOK, ChatResponse{
Response: result.Response,
MCPExecutionIDs: result.MCPExecutionIDs,
ConversationID: conversationID,
Time: time.Now(),
})
}

View File

@@ -0,0 +1,102 @@
package handler
import (
"net/http"
"strconv"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ConversationHandler 对话处理器
type ConversationHandler struct {
db *database.DB
logger *zap.Logger
}
// NewConversationHandler 创建新的对话处理器
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
return &ConversationHandler{
db: db,
logger: logger,
}
}
// CreateConversationRequest 创建对话请求
type CreateConversationRequest struct {
Title string `json:"title"`
}
// CreateConversation 创建新对话
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
var req CreateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
title := req.Title
if title == "" {
title = "新对话"
}
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conv)
}
// ListConversations 列出对话
func (h *ConversationHandler) ListConversations(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50")
offsetStr := c.DefaultQuery("offset", "0")
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
if limit <= 0 || limit > 100 {
limit = 50
}
conversations, err := h.db.ListConversations(limit, offset)
if err != nil {
h.logger.Error("获取对话列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conversations)
}
// GetConversation 获取对话
func (h *ConversationHandler) GetConversation(c *gin.Context) {
id := c.Param("id")
conv, err := h.db.GetConversation(id)
if err != nil {
h.logger.Error("获取对话失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
c.JSON(http.StatusOK, conv)
}
// DeleteConversation 删除对话
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteConversation(id); err != nil {
h.logger.Error("删除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}

View File

@@ -0,0 +1,92 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// MonitorHandler 监控处理器
type MonitorHandler struct {
mcpServer *mcp.Server
executor *security.Executor
logger *zap.Logger
vulns []security.Vulnerability
}
// NewMonitorHandler 创建新的监控处理器
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, logger *zap.Logger) *MonitorHandler {
return &MonitorHandler{
mcpServer: mcpServer,
executor: executor,
logger: logger,
vulns: []security.Vulnerability{},
}
}
// MonitorResponse 监控响应
type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"`
Vulnerabilities []security.Vulnerability `json:"vulnerabilities"`
Report map[string]interface{} `json:"report"`
Timestamp time.Time `json:"timestamp"`
}
// Monitor 获取监控信息
func (h *MonitorHandler) Monitor(c *gin.Context) {
// 获取所有执行记录
executions := h.mcpServer.GetAllExecutions()
// 分析执行结果,提取漏洞
for _, exec := range executions {
if exec.Status == "completed" && exec.Result != nil {
vulns := h.executor.AnalyzeResults(exec.ToolName, exec.Result)
h.vulns = append(h.vulns, vulns...)
}
}
// 获取统计信息
stats := h.mcpServer.GetStats()
// 生成报告
report := h.executor.GetVulnerabilityReport(h.vulns)
c.JSON(http.StatusOK, MonitorResponse{
Executions: executions,
Stats: stats,
Vulnerabilities: h.vulns,
Report: report,
Timestamp: time.Now(),
})
}
// GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id")
exec, exists := h.mcpServer.GetExecution(id)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
return
}
c.JSON(http.StatusOK, exec)
}
// GetStats 获取统计信息
func (h *MonitorHandler) GetStats(c *gin.Context) {
stats := h.mcpServer.GetStats()
c.JSON(http.StatusOK, stats)
}
// GetVulnerabilities 获取漏洞列表
func (h *MonitorHandler) GetVulnerabilities(c *gin.Context) {
report := h.executor.GetVulnerabilityReport(h.vulns)
c.JSON(http.StatusOK, report)
}

60
internal/logger/logger.go Normal file
View File

@@ -0,0 +1,60 @@
package logger
import (
"os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type Logger struct {
*zap.Logger
}
func New(level, output string) *Logger {
var zapLevel zapcore.Level
switch level {
case "debug":
zapLevel = zapcore.DebugLevel
case "info":
zapLevel = zapcore.InfoLevel
case "warn":
zapLevel = zapcore.WarnLevel
case "error":
zapLevel = zapcore.ErrorLevel
default:
zapLevel = zapcore.InfoLevel
}
config := zap.NewProductionConfig()
config.Level = zap.NewAtomicLevelAt(zapLevel)
config.EncoderConfig.TimeKey = "timestamp"
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
var writeSyncer zapcore.WriteSyncer
if output == "stdout" {
writeSyncer = zapcore.AddSync(os.Stdout)
} else {
file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
writeSyncer = zapcore.AddSync(os.Stdout)
} else {
writeSyncer = zapcore.AddSync(file)
}
}
core := zapcore.NewCore(
zapcore.NewJSONEncoder(config.EncoderConfig),
writeSyncer,
zapLevel,
)
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
return &Logger{Logger: logger}
}
func (l *Logger) Fatal(msg string, fields ...interface{}) {
l.Logger.Fatal(msg, zap.Any("fields", fields))
}

798
internal/mcp/server.go Normal file
View File

@@ -0,0 +1,798 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Server MCP服务器
type Server struct {
tools map[string]ToolHandler
toolDefs map[string]Tool // 工具定义
executions map[string]*ToolExecution
stats map[string]*ToolStats
prompts map[string]*Prompt // 提示词模板
resources map[string]*Resource // 资源
mu sync.RWMutex
logger *zap.Logger
}
// ToolHandler 工具处理函数
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
// NewServer 创建新的MCP服务器
func NewServer(logger *zap.Logger) *Server {
s := &Server{
tools: make(map[string]ToolHandler),
toolDefs: make(map[string]Tool),
executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats),
prompts: make(map[string]*Prompt),
resources: make(map[string]*Resource),
logger: logger,
}
// 初始化默认提示词和资源
s.initDefaultPrompts()
s.initDefaultResources()
return s
}
// RegisterTool 注册工具
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
s.mu.Lock()
defer s.mu.Unlock()
s.tools[tool.Name] = handler
s.toolDefs[tool.Name] = tool
// 自动为工具创建资源文档
resourceURI := fmt.Sprintf("tool://%s", tool.Name)
s.resources[resourceURI] = &Resource{
URI: resourceURI,
Name: fmt.Sprintf("%s工具文档", tool.Name),
Description: tool.Description,
MimeType: "text/plain",
}
}
// HandleHTTP 处理HTTP请求
func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
s.sendError(w, nil, -32700, "Parse error", err.Error())
return
}
var msg Message
if err := json.Unmarshal(body, &msg); err != nil {
s.sendError(w, nil, -32700, "Parse error", err.Error())
return
}
// 处理消息
response := s.handleMessage(&msg)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// handleMessage 处理MCP消息
func (s *Server) handleMessage(msg *Message) *Message {
if msg.ID == "" {
msg.ID = uuid.New().String()
}
switch msg.Method {
case "initialize":
return s.handleInitialize(msg)
case "tools/list":
return s.handleListTools(msg)
case "tools/call":
return s.handleCallTool(msg)
case "prompts/list":
return s.handleListPrompts(msg)
case "prompts/get":
return s.handleGetPrompt(msg)
case "resources/list":
return s.handleListResources(msg)
case "resources/read":
return s.handleReadResource(msg)
case "sampling/request":
return s.handleSamplingRequest(msg)
default:
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Method not found"},
}
}
}
// handleInitialize 处理初始化请求
func (s *Server) handleInitialize(msg *Message) *Message {
var req InitializeRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
response := InitializeResponse{
ProtocolVersion: ProtocolVersion,
Capabilities: ServerCapabilities{
Tools: map[string]interface{}{
"listChanged": true,
},
Prompts: map[string]interface{}{
"listChanged": true,
},
Resources: map[string]interface{}{
"subscribe": true,
"listChanged": true,
},
Sampling: map[string]interface{}{},
},
ServerInfo: ServerInfo{
Name: "CyberStrikeAI",
Version: "1.0.0",
},
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: result,
}
}
// handleListTools 处理列出工具请求
func (s *Server) handleListTools(msg *Message) *Message {
s.mu.RLock()
tools := make([]Tool, 0, len(s.toolDefs))
for _, tool := range s.toolDefs {
tools = append(tools, tool)
}
s.mu.RUnlock()
response := ListToolsResponse{Tools: tools}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: result,
}
}
// handleCallTool 处理工具调用请求
func (s *Server) handleCallTool(msg *Message) *Message {
var req CallToolRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
// 创建执行记录
executionID := uuid.New().String()
execution := &ToolExecution{
ID: executionID,
ToolName: req.Name,
Arguments: req.Arguments,
Status: "running",
StartTime: time.Now(),
}
s.mu.Lock()
s.executions[executionID] = execution
s.mu.Unlock()
// 更新统计
s.updateStats(req.Name, false)
// 执行工具
s.mu.RLock()
handler, exists := s.tools[req.Name]
s.mu.RUnlock()
if !exists {
execution.Status = "failed"
execution.Error = "Tool not found"
now := time.Now()
execution.EndTime = &now
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Tool not found"},
}
}
// 同步执行所有工具,确保错误能正确返回
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
s.logger.Info("开始执行工具",
zap.String("toolName", req.Name),
zap.Any("arguments", req.Arguments),
)
result, err := handler(ctx, req.Arguments)
s.mu.Lock()
now := time.Now()
execution.EndTime = &now
execution.Duration = now.Sub(execution.StartTime)
if err != nil {
execution.Status = "failed"
execution.Error = err.Error()
s.updateStats(req.Name, true)
s.mu.Unlock()
s.logger.Error("工具执行失败",
zap.String("toolName", req.Name),
zap.Error(err),
)
// 返回错误结果
errorResult, _ := json.Marshal(CallToolResponse{
Content: []Content{
{Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)},
},
IsError: true,
})
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: errorResult,
}
}
// 检查result是否为错误
if result != nil && result.IsError {
execution.Status = "failed"
if len(result.Content) > 0 {
execution.Error = result.Content[0].Text
}
s.updateStats(req.Name, true)
} else {
execution.Status = "completed"
execution.Result = result
s.updateStats(req.Name, false)
}
s.mu.Unlock()
// 返回执行结果
if result == nil {
result = &ToolResult{
Content: []Content{
{Type: "text", Text: "工具执行完成,但未返回结果"},
},
}
}
resultJSON, _ := json.Marshal(CallToolResponse{
Content: result.Content,
IsError: result.IsError,
})
s.logger.Info("工具执行完成",
zap.String("toolName", req.Name),
zap.Bool("isError", result.IsError),
)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: resultJSON,
}
}
// updateStats 更新统计信息
func (s *Server) updateStats(toolName string, failed bool) {
if s.stats[toolName] == nil {
s.stats[toolName] = &ToolStats{
ToolName: toolName,
}
}
stats := s.stats[toolName]
stats.TotalCalls++
now := time.Now()
stats.LastCallTime = &now
if failed {
stats.FailedCalls++
} else {
stats.SuccessCalls++
}
}
// GetExecution 获取执行记录
func (s *Server) GetExecution(id string) (*ToolExecution, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
exec, exists := s.executions[id]
return exec, exists
}
// GetAllExecutions 获取所有执行记录
func (s *Server) GetAllExecutions() []*ToolExecution {
s.mu.RLock()
defer s.mu.RUnlock()
executions := make([]*ToolExecution, 0, len(s.executions))
for _, exec := range s.executions {
executions = append(executions, exec)
}
return executions
}
// GetStats 获取统计信息
func (s *Server) GetStats() map[string]*ToolStats {
s.mu.RLock()
defer s.mu.RUnlock()
stats := make(map[string]*ToolStats)
for k, v := range s.stats {
stats[k] = v
}
return stats
}
// CallTool 直接调用工具(用于内部调用)
func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
s.mu.RLock()
handler, exists := s.tools[toolName]
s.mu.RUnlock()
if !exists {
return nil, "", fmt.Errorf("工具 %s 未找到", toolName)
}
// 创建执行记录
executionID := uuid.New().String()
execution := &ToolExecution{
ID: executionID,
ToolName: toolName,
Arguments: args,
Status: "running",
StartTime: time.Now(),
}
s.mu.Lock()
s.executions[executionID] = execution
s.mu.Unlock()
// 更新统计
s.updateStats(toolName, false)
// 执行工具
result, err := handler(ctx, args)
s.mu.Lock()
now := time.Now()
execution.EndTime = &now
execution.Duration = now.Sub(execution.StartTime)
if err != nil {
execution.Status = "failed"
execution.Error = err.Error()
s.updateStats(toolName, true)
s.mu.Unlock()
return nil, executionID, err
} else {
execution.Status = "completed"
execution.Result = result
s.updateStats(toolName, false)
s.mu.Unlock()
return result, executionID, nil
}
}
// initDefaultPrompts 初始化默认提示词模板
func (s *Server) initDefaultPrompts() {
s.mu.Lock()
defer s.mu.Unlock()
// 网络安全测试提示词
s.prompts["security_scan"] = &Prompt{
Name: "security_scan",
Description: "生成网络安全扫描任务的提示词",
Arguments: []PromptArgument{
{Name: "target", Description: "扫描目标IP地址或域名", Required: true},
{Name: "scan_type", Description: "扫描类型port, vuln, web等", Required: false},
},
}
// 渗透测试提示词
s.prompts["penetration_test"] = &Prompt{
Name: "penetration_test",
Description: "生成渗透测试任务的提示词",
Arguments: []PromptArgument{
{Name: "target", Description: "测试目标", Required: true},
{Name: "scope", Description: "测试范围", Required: false},
},
}
}
// initDefaultResources 初始化默认资源
// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源
func (s *Server) initDefaultResources() {
// 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码
}
// handleListPrompts 处理列出提示词请求
func (s *Server) handleListPrompts(msg *Message) *Message {
s.mu.RLock()
prompts := make([]Prompt, 0, len(s.prompts))
for _, prompt := range s.prompts {
prompts = append(prompts, *prompt)
}
s.mu.RUnlock()
response := ListPromptsResponse{
Prompts: prompts,
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: result,
}
}
// handleGetPrompt 处理获取提示词请求
func (s *Server) handleGetPrompt(msg *Message) *Message {
var req GetPromptRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
s.mu.RLock()
prompt, exists := s.prompts[req.Name]
s.mu.RUnlock()
if !exists {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Prompt not found"},
}
}
// 根据提示词名称生成消息
messages := s.generatePromptMessages(prompt, req.Arguments)
response := GetPromptResponse{
Messages: messages,
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: result,
}
}
// generatePromptMessages 生成提示词消息
func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage {
messages := []PromptMessage{}
switch prompt.Name {
case "security_scan":
target, _ := args["target"].(string)
scanType, _ := args["scan_type"].(string)
if scanType == "" {
scanType = "comprehensive"
}
content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括
1. 端口扫描和服务识别
2. 漏洞检测
3. Web应用安全测试
4. 生成详细的安全报告`, target, scanType)
messages = append(messages, PromptMessage{
Role: "user",
Content: content,
})
case "penetration_test":
target, _ := args["target"].(string)
scope, _ := args["scope"].(string)
content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target)
if scope != "" {
content += fmt.Sprintf("测试范围:%s", scope)
}
content += "\n请按照OWASP Top 10进行全面的安全测试。"
messages = append(messages, PromptMessage{
Role: "user",
Content: content,
})
default:
messages = append(messages, PromptMessage{
Role: "user",
Content: "请执行安全测试任务",
})
}
return messages
}
// handleListResources 处理列出资源请求
func (s *Server) handleListResources(msg *Message) *Message {
s.mu.RLock()
resources := make([]Resource, 0, len(s.resources))
for _, resource := range s.resources {
resources = append(resources, *resource)
}
s.mu.RUnlock()
response := ListResourcesResponse{
Resources: resources,
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: result,
}
}
// handleReadResource 处理读取资源请求
func (s *Server) handleReadResource(msg *Message) *Message {
var req ReadResourceRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
s.mu.RLock()
resource, exists := s.resources[req.URI]
s.mu.RUnlock()
if !exists {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Resource not found"},
}
}
// 生成资源内容
content := s.generateResourceContent(resource)
response := ReadResourceResponse{
Contents: []ResourceContent{content},
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: result,
}
}
// generateResourceContent 生成资源内容
func (s *Server) generateResourceContent(resource *Resource) ResourceContent {
content := ResourceContent{
URI: resource.URI,
MimeType: resource.MimeType,
}
// 如果是工具资源,生成详细文档
if strings.HasPrefix(resource.URI, "tool://") {
toolName := strings.TrimPrefix(resource.URI, "tool://")
content.Text = s.generateToolDocumentation(toolName, resource)
} else {
// 其他资源使用描述或默认内容
content.Text = resource.Description
}
return content
}
// generateToolDocumentation 生成工具文档
func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string {
// 获取工具定义以获取更详细的信息
s.mu.RLock()
tool, hasTool := s.toolDefs[toolName]
s.mu.RUnlock()
// 为常见工具生成详细文档
switch toolName {
case "nmap":
return `Nmap (Network Mapper) 是一个强大的网络扫描工具。
主要功能:
- 端口扫描:发现目标主机开放的端口
- 服务识别:识别运行在端口上的服务
- 版本检测:检测服务版本信息
- 操作系统检测:识别目标操作系统
常用命令:
- nmap -sT target # TCP连接扫描
- nmap -sV target # 版本检测
- nmap -sC target # 默认脚本扫描
- nmap -p 1-1000 target # 扫描指定端口范围
参数说明:
- target: 目标IP地址或域名必需
- ports: 端口范围,例如: 1-1000可选`
case "sqlmap":
return `SQLMap 是一个自动化的SQL注入检测和利用工具。
主要功能:
- 自动检测SQL注入漏洞
- 数据库指纹识别
- 数据提取
- 文件系统访问
常用命令:
- sqlmap -u "http://target.com/page?id=1" # 检测URL参数
- sqlmap -u "http://target.com" --forms # 检测表单
- sqlmap -u "http://target.com" --dbs # 列出数据库
参数说明:
- url: 目标URL必需`
case "nikto":
return `Nikto 是一个Web服务器扫描工具。
主要功能:
- Web服务器漏洞扫描
- 检测过时的服务器软件
- 检测危险文件和程序
- 检测服务器配置问题
常用命令:
- nikto -h target # 扫描目标主机
- nikto -h target -p 80,443 # 扫描指定端口
参数说明:
- target: 目标URL必需`
case "dirb":
return `Dirb 是一个Web内容扫描器。
主要功能:
- 扫描Web目录和文件
- 发现隐藏的目录和文件
- 支持自定义字典
常用命令:
- dirb url # 扫描目标URL
- dirb url -w wordlist.txt # 使用自定义字典
参数说明:
- target: 目标URL必需`
case "exec":
return `Exec 工具用于执行系统命令。
⚠️ 警告:此工具可以执行任意系统命令,请谨慎使用!
参数说明:
- command: 要执行的系统命令(必需)
- shell: 使用的shell默认为sh可选
- workdir: 工作目录(可选)`
default:
// 对于其他工具,使用工具定义中的描述信息
if hasTool {
doc := fmt.Sprintf("%s\n\n", resource.Description)
if tool.InputSchema != nil {
if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok {
doc += "参数说明:\n"
for paramName, paramInfo := range props {
if paramMap, ok := paramInfo.(map[string]interface{}); ok {
if desc, ok := paramMap["description"].(string); ok {
doc += fmt.Sprintf("- %s: %s\n", paramName, desc)
}
}
}
}
}
return doc
}
return resource.Description
}
}
// handleSamplingRequest 处理采样请求
func (s *Server) handleSamplingRequest(msg *Message) *Message {
var req SamplingRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
// 注意采样功能通常需要连接到实际的LLM服务
// 这里返回一个占位符响应实际实现需要集成LLM API
s.logger.Warn("Sampling request received but not fully implemented",
zap.Any("request", req),
)
response := SamplingResponse{
Content: []SamplingContent{
{
Type: "text",
Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。",
},
},
StopReason: "length",
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Type: MessageTypeResponse,
Version: "2.0",
Result: result,
}
}
// RegisterPrompt 注册提示词模板
func (s *Server) RegisterPrompt(prompt *Prompt) {
s.mu.Lock()
defer s.mu.Unlock()
s.prompts[prompt.Name] = prompt
}
// RegisterResource 注册资源
func (s *Server) RegisterResource(resource *Resource) {
s.mu.Lock()
defer s.mu.Unlock()
s.resources[resource.URI] = resource
}
// sendError 发送错误响应
func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) {
response := Message{
ID: fmt.Sprintf("%v", id),
Type: MessageTypeError,
Error: &Error{Code: code, Message: message, Data: data},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}

232
internal/mcp/types.go Normal file
View File

@@ -0,0 +1,232 @@
package mcp
import (
"encoding/json"
"time"
)
// MCP消息类型
const (
MessageTypeRequest = "request"
MessageTypeResponse = "response"
MessageTypeError = "error"
MessageTypeNotify = "notify"
)
// MCP协议版本
const ProtocolVersion = "2024-11-05"
// Message 表示MCP消息
type Message struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
Version string `json:"jsonrpc,omitempty"`
}
// Error 表示MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// Tool 表示MCP工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema map[string]interface{} `json:"inputSchema"`
}
// ToolCall 表示工具调用
type ToolCall struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// ToolResult 表示工具执行结果
type ToolResult struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// Content 表示内容
type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}
// InitializeRequest 初始化请求
type InitializeRequest struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]interface{} `json:"capabilities"`
ClientInfo ClientInfo `json:"clientInfo"`
}
// ClientInfo 客户端信息
type ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
type ServerCapabilities struct {
Tools map[string]interface{} `json:"tools,omitempty"`
Prompts map[string]interface{} `json:"prompts,omitempty"`
Resources map[string]interface{} `json:"resources,omitempty"`
Sampling map[string]interface{} `json:"sampling,omitempty"`
}
// ServerInfo 服务器信息
type ServerInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// ListToolsRequest 列出工具请求
type ListToolsRequest struct{}
// ListToolsResponse 列出工具响应
type ListToolsResponse struct {
Tools []Tool `json:"tools"`
}
// ListPromptsResponse 列出提示词响应
type ListPromptsResponse struct {
Prompts []Prompt `json:"prompts"`
}
// ListResourcesResponse 列出资源响应
type ListResourcesResponse struct {
Resources []Resource `json:"resources"`
}
// CallToolRequest 调用工具请求
type CallToolRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// CallToolResponse 调用工具响应
type CallToolResponse struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// ToolExecution 工具执行记录
type ToolExecution struct {
ID string `json:"id"`
ToolName string `json:"toolName"`
Arguments map[string]interface{} `json:"arguments"`
Status string `json:"status"` // pending, running, completed, failed
Result *ToolResult `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
}
// ToolStats 工具统计信息
type ToolStats struct {
ToolName string `json:"toolName"`
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
}
// Prompt 提示词模板
type Prompt struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Arguments []PromptArgument `json:"arguments,omitempty"`
}
// PromptArgument 提示词参数
type PromptArgument struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
}
// GetPromptRequest 获取提示词请求
type GetPromptRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
// GetPromptResponse 获取提示词响应
type GetPromptResponse struct {
Messages []PromptMessage `json:"messages"`
}
// PromptMessage 提示词消息
type PromptMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// Resource 资源
type Resource struct {
URI string `json:"uri"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
MimeType string `json:"mimeType,omitempty"`
}
// ReadResourceRequest 读取资源请求
type ReadResourceRequest struct {
URI string `json:"uri"`
}
// ReadResourceResponse 读取资源响应
type ReadResourceResponse struct {
Contents []ResourceContent `json:"contents"`
}
// ResourceContent 资源内容
type ResourceContent struct {
URI string `json:"uri"`
MimeType string `json:"mimeType,omitempty"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}
// SamplingRequest 采样请求
type SamplingRequest struct {
Messages []SamplingMessage `json:"messages"`
Model string `json:"model,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// SamplingMessage 采样消息
type SamplingMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// SamplingResponse 采样响应
type SamplingResponse struct {
Content []SamplingContent `json:"content"`
Model string `json:"model,omitempty"`
StopReason string `json:"stopReason,omitempty"`
}
// SamplingContent 采样内容
type SamplingContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}

View File

@@ -0,0 +1,730 @@
package security
import (
"context"
"fmt"
"os/exec"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// Executor 安全工具执行器
type Executor struct {
config *config.SecurityConfig
mcpServer *mcp.Server
logger *zap.Logger
}
// NewExecutor 创建新的执行器
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
return &Executor{
config: cfg,
mcpServer: mcpServer,
logger: logger,
}
}
// ExecuteTool 执行安全工具
func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[string]interface{}) (*mcp.ToolResult, error) {
e.logger.Info("ExecuteTool被调用",
zap.String("toolName", toolName),
zap.Any("args", args),
)
// 特殊处理exec工具直接执行系统命令
if toolName == "exec" {
e.logger.Info("执行exec工具")
return e.executeSystemCommand(ctx, args)
}
// 查找工具配置
var toolConfig *config.ToolConfig
for i := range e.config.Tools {
if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
toolConfig = &e.config.Tools[i]
break
}
}
if toolConfig == nil {
e.logger.Error("工具未找到或未启用",
zap.String("toolName", toolName),
zap.Int("totalTools", len(e.config.Tools)),
)
return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
}
e.logger.Info("找到工具配置",
zap.String("toolName", toolName),
zap.String("command", toolConfig.Command),
zap.Strings("args", toolConfig.Args),
)
// 构建命令 - 根据工具类型使用不同的参数格式
cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
e.logger.Info("构建命令参数完成",
zap.String("toolName", toolName),
zap.Strings("cmdArgs", cmdArgs),
zap.Int("argsCount", len(cmdArgs)),
)
// 验证命令参数
if len(cmdArgs) == 0 {
e.logger.Warn("命令参数为空",
zap.String("toolName", toolName),
zap.Any("inputArgs", args),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("错误: 工具 %s 缺少必需的参数。接收到的参数: %v", toolName, args),
},
},
IsError: true,
}, nil
}
// 执行命令
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
e.logger.Info("执行安全工具",
zap.String("tool", toolName),
zap.Strings("args", cmdArgs),
)
output, err := cmd.CombinedOutput()
if err != nil {
e.logger.Error("工具执行失败",
zap.String("tool", toolName),
zap.Error(err),
zap.String("output", string(output)),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("工具执行失败: %v\n输出: %s", err, string(output)),
},
},
IsError: true,
}, nil
}
e.logger.Info("工具执行成功",
zap.String("tool", toolName),
zap.String("output", string(output)),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: string(output),
},
},
IsError: false,
}, nil
}
// RegisterTools 注册工具到MCP服务器
func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
e.logger.Info("开始注册工具",
zap.Int("totalTools", len(e.config.Tools)),
)
for i, toolConfig := range e.config.Tools {
if !toolConfig.Enabled {
e.logger.Debug("跳过未启用的工具",
zap.String("tool", toolConfig.Name),
)
continue
}
// 创建工具配置的副本,避免闭包问题
toolName := toolConfig.Name
toolConfigCopy := toolConfig
tool := mcp.Tool{
Name: toolConfigCopy.Name,
Description: toolConfigCopy.Description,
InputSchema: e.buildInputSchema(&toolConfigCopy),
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
e.logger.Info("工具handler被调用",
zap.String("toolName", toolName),
zap.Any("args", args),
)
return e.ExecuteTool(ctx, toolName, args)
}
mcpServer.RegisterTool(tool, handler)
e.logger.Info("注册安全工具成功",
zap.String("tool", toolConfigCopy.Name),
zap.String("command", toolConfigCopy.Command),
zap.Int("index", i),
)
}
e.logger.Info("工具注册完成",
zap.Int("registeredCount", len(e.config.Tools)),
)
}
// buildCommandArgs 构建命令参数
func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConfig, args map[string]interface{}) []string {
cmdArgs := make([]string, 0)
// 如果配置中定义了参数映射,使用配置中的映射规则
if len(toolConfig.Parameters) > 0 {
// 先添加固定参数
cmdArgs = append(cmdArgs, toolConfig.Args...)
// 按位置参数排序
positionalParams := make([]config.ParameterConfig, 0)
flagParams := make([]config.ParameterConfig, 0)
for _, param := range toolConfig.Parameters {
if param.Position != nil {
positionalParams = append(positionalParams, param)
} else {
flagParams = append(flagParams, param)
}
}
// 对位置参数按位置排序
for i := 0; i < len(positionalParams); i++ {
for _, param := range positionalParams {
if param.Position != nil && *param.Position == i {
value := e.getParamValue(args, param)
if value == nil {
if param.Required {
// 必需参数缺失,返回空数组让上层处理错误
e.logger.Warn("缺少必需的位置参数",
zap.String("tool", toolName),
zap.String("param", param.Name),
zap.Int("position", *param.Position),
)
return []string{}
}
break
}
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
break
}
}
}
// 处理标志参数
for _, param := range flagParams {
value := e.getParamValue(args, param)
if value == nil {
if param.Required {
// 必需参数缺失,返回空数组让上层处理错误
e.logger.Warn("缺少必需的标志参数",
zap.String("tool", toolName),
zap.String("param", param.Name),
)
return []string{}
}
continue
}
// 布尔值特殊处理:如果为 false跳过如果为 true只添加标志
if param.Type == "bool" {
if boolVal, ok := value.(bool); ok {
if !boolVal {
continue // false 时不添加任何参数
}
// true 时只添加标志,不添加值
if param.Flag != "" {
cmdArgs = append(cmdArgs, param.Flag)
}
continue
}
}
format := param.Format
if format == "" {
format = "flag" // 默认格式
}
switch format {
case "flag":
// --flag value 或 -f value
if param.Flag != "" {
cmdArgs = append(cmdArgs, param.Flag)
}
formattedValue := e.formatParamValue(param, value)
if formattedValue != "" {
cmdArgs = append(cmdArgs, formattedValue)
}
case "combined":
// --flag=value 或 -f=value
if param.Flag != "" {
cmdArgs = append(cmdArgs, fmt.Sprintf("%s=%s", param.Flag, e.formatParamValue(param, value)))
} else {
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
}
case "template":
// 使用模板字符串
if param.Template != "" {
template := param.Template
template = strings.ReplaceAll(template, "{flag}", param.Flag)
template = strings.ReplaceAll(template, "{value}", e.formatParamValue(param, value))
template = strings.ReplaceAll(template, "{name}", param.Name)
cmdArgs = append(cmdArgs, strings.Fields(template)...)
} else {
// 如果没有模板,使用默认格式
if param.Flag != "" {
cmdArgs = append(cmdArgs, param.Flag)
}
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
}
case "positional":
// 位置参数(已在上面处理)
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
default:
// 默认:直接添加值
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
}
}
return cmdArgs
}
// 向后兼容:如果没有定义参数,使用旧的硬编码逻辑
switch toolName {
case "nmap":
// nmap -sT -sV -sC target [ports]
// 使用 -sT (TCP连接扫描) 而不是 -sS (SYN扫描),因为 -sS 需要root权限
e.logger.Debug("处理nmap参数",
zap.Any("args", args),
)
// 尝试多种方式获取target参数
var target string
var ok bool
// 方式1: 直接获取target
if target, ok = args["target"].(string); !ok || target == "" {
// 方式2: 尝试从tool字段获取兼容某些格式
if toolVal, exists := args["tool"]; exists {
if toolMap, ok := toolVal.(map[string]interface{}); ok {
if t, ok := toolMap["target"].(string); ok {
target = t
}
}
}
}
if target == "" {
e.logger.Warn("nmap缺少target参数",
zap.Any("args", args),
)
return cmdArgs // 返回空数组,让上层处理错误
}
e.logger.Debug("提取到target",
zap.String("target", target),
)
// 处理URL格式的目标提取域名
if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
// 提取域名部分
target = strings.TrimPrefix(target, "http://")
target = strings.TrimPrefix(target, "https://")
// 移除路径部分
if idx := strings.Index(target, "/"); idx != -1 {
target = target[:idx]
}
}
// 添加扫描选项:-sT (TCP连接扫描不需要root权限), -sV (版本检测), -sC (默认脚本)
cmdArgs = append(cmdArgs, "-sT", "-sV", "-sC")
// 添加端口范围(如果指定)
if ports, ok := args["ports"].(string); ok && ports != "" {
cmdArgs = append(cmdArgs, "-p", ports)
}
// 添加目标
cmdArgs = append(cmdArgs, target)
e.logger.Debug("nmap命令参数构建完成",
zap.Strings("cmdArgs", cmdArgs),
)
case "sqlmap":
// sqlmap -u url
if url, ok := args["url"].(string); ok {
cmdArgs = append(cmdArgs, "-u", url, "--batch", "--level=3", "--risk=2")
}
case "nikto":
// nikto -h target
if target, ok := args["target"].(string); ok {
cmdArgs = append(cmdArgs, "-h", target)
}
case "dirb":
// dirb url
if url, ok := args["url"].(string); ok {
cmdArgs = append(cmdArgs, url)
}
default:
// 通用处理
cmdArgs = append(cmdArgs, toolConfig.Args...)
for key, value := range args {
if key == "_tool_name" {
continue
}
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", key))
if strValue, ok := value.(string); ok {
cmdArgs = append(cmdArgs, strValue)
} else {
cmdArgs = append(cmdArgs, fmt.Sprintf("%v", value))
}
}
}
return cmdArgs
}
// getParamValue 获取参数值,支持默认值
func (e *Executor) getParamValue(args map[string]interface{}, param config.ParameterConfig) interface{} {
// 从参数中获取值
if value, ok := args[param.Name]; ok && value != nil {
return value
}
// 如果参数是必需的但没有提供,返回 nil让上层处理错误
if param.Required {
return nil
}
// 返回默认值
return param.Default
}
// formatParamValue 格式化参数值
func (e *Executor) formatParamValue(param config.ParameterConfig, value interface{}) string {
switch param.Type {
case "bool":
// 布尔值应该在上层处理,这里不应该被调用
if boolVal, ok := value.(bool); ok {
return fmt.Sprintf("%v", boolVal)
}
return "false"
case "array":
// 数组:转换为逗号分隔的字符串
if arr, ok := value.([]interface{}); ok {
strs := make([]string, 0, len(arr))
for _, item := range arr {
strs = append(strs, fmt.Sprintf("%v", item))
}
return strings.Join(strs, ",")
}
return fmt.Sprintf("%v", value)
default:
return fmt.Sprintf("%v", value)
}
}
// executeSystemCommand 执行系统命令
func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
// 获取命令
command, ok := args["command"].(string)
if !ok {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: 缺少command参数",
},
},
IsError: true,
}, nil
}
if command == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: command参数不能为空",
},
},
IsError: true,
}, nil
}
// 安全检查:记录执行的命令
e.logger.Warn("执行系统命令",
zap.String("command", command),
)
// 获取shell类型可选默认为sh
shell := "sh"
if s, ok := args["shell"].(string); ok && s != "" {
shell = s
}
// 获取工作目录(可选)
workDir := ""
if wd, ok := args["workdir"].(string); ok && wd != "" {
workDir = wd
}
// 构建命令
var cmd *exec.Cmd
if workDir != "" {
cmd = exec.CommandContext(ctx, shell, "-c", command)
cmd.Dir = workDir
} else {
cmd = exec.CommandContext(ctx, shell, "-c", command)
}
// 执行命令
e.logger.Info("执行系统命令",
zap.String("command", command),
zap.String("shell", shell),
zap.String("workdir", workDir),
)
output, err := cmd.CombinedOutput()
if err != nil {
e.logger.Error("系统命令执行失败",
zap.String("command", command),
zap.Error(err),
zap.String("output", string(output)),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
},
},
IsError: true,
}, nil
}
e.logger.Info("系统命令执行成功",
zap.String("command", command),
zap.String("output_length", fmt.Sprintf("%d", len(output))),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: string(output),
},
},
IsError: false,
}, nil
}
// buildInputSchema 构建输入模式
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
schema := map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
}
// 如果配置中定义了参数,优先使用配置中的参数定义
if len(toolConfig.Parameters) > 0 {
properties := make(map[string]interface{})
required := []string{}
for _, param := range toolConfig.Parameters {
prop := map[string]interface{}{
"type": param.Type,
"description": param.Description,
}
// 添加默认值
if param.Default != nil {
prop["default"] = param.Default
}
// 添加枚举选项
if len(param.Options) > 0 {
prop["enum"] = param.Options
}
properties[param.Name] = prop
// 添加到必需参数列表
if param.Required {
required = append(required, param.Name)
}
}
schema["properties"] = properties
schema["required"] = required
return schema
}
// 向后兼容:如果没有定义参数,使用旧的硬编码逻辑
switch toolConfig.Name {
case "nmap":
schema["properties"] = map[string]interface{}{
"target": map[string]interface{}{
"type": "string",
"description": "目标IP地址或域名",
},
"ports": map[string]interface{}{
"type": "string",
"description": "端口范围,例如: 1-1000",
},
}
schema["required"] = []string{"target"}
case "sqlmap":
schema["properties"] = map[string]interface{}{
"url": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
}
schema["required"] = []string{"url"}
case "nikto", "dirb":
schema["properties"] = map[string]interface{}{
"target": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
}
schema["required"] = []string{"target"}
case "exec":
schema["properties"] = map[string]interface{}{
"command": map[string]interface{}{
"type": "string",
"description": "要执行的系统命令",
},
"shell": map[string]interface{}{
"type": "string",
"description": "使用的shell可选默认为sh",
},
"workdir": map[string]interface{}{
"type": "string",
"description": "工作目录(可选)",
},
}
schema["required"] = []string{"command"}
}
return schema
}
// Vulnerability 漏洞信息
type Vulnerability struct {
ID string `json:"id"`
Type string `json:"type"`
Severity string `json:"severity"` // low, medium, high, critical
Title string `json:"title"`
Description string `json:"description"`
Target string `json:"target"`
FoundAt time.Time `json:"foundAt"`
Details string `json:"details"`
}
// AnalyzeResults 分析工具执行结果,提取漏洞信息
func (e *Executor) AnalyzeResults(toolName string, result *mcp.ToolResult) []Vulnerability {
vulnerabilities := []Vulnerability{}
if result.IsError {
return vulnerabilities
}
// 分析输出内容
for _, content := range result.Content {
if content.Type == "text" {
vulns := e.parseToolOutput(toolName, content.Text)
vulnerabilities = append(vulnerabilities, vulns...)
}
}
return vulnerabilities
}
// parseToolOutput 解析工具输出
func (e *Executor) parseToolOutput(toolName, output string) []Vulnerability {
vulnerabilities := []Vulnerability{}
// 简单的漏洞检测逻辑
outputLower := strings.ToLower(output)
// SQL注入检测
if strings.Contains(outputLower, "sql injection") || strings.Contains(outputLower, "sqli") {
vulnerabilities = append(vulnerabilities, Vulnerability{
ID: fmt.Sprintf("sql-%d", time.Now().Unix()),
Type: "SQL Injection",
Severity: "high",
Title: "SQL注入漏洞",
Description: "检测到潜在的SQL注入漏洞",
FoundAt: time.Now(),
Details: output,
})
}
// XSS检测
if strings.Contains(outputLower, "xss") || strings.Contains(outputLower, "cross-site scripting") {
vulnerabilities = append(vulnerabilities, Vulnerability{
ID: fmt.Sprintf("xss-%d", time.Now().Unix()),
Type: "XSS",
Severity: "medium",
Title: "跨站脚本攻击漏洞",
Description: "检测到潜在的XSS漏洞",
FoundAt: time.Now(),
Details: output,
})
}
// 开放端口检测
if toolName == "nmap" {
lines := strings.Split(output, "\n")
for _, line := range lines {
if strings.Contains(line, "open") && strings.Contains(line, "port") {
vulnerabilities = append(vulnerabilities, Vulnerability{
ID: fmt.Sprintf("port-%d", time.Now().Unix()),
Type: "Open Port",
Severity: "low",
Title: "开放端口",
Description: fmt.Sprintf("发现开放端口: %s", line),
FoundAt: time.Now(),
Details: line,
})
}
}
}
return vulnerabilities
}
// GetVulnerabilityReport 生成漏洞报告
func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[string]interface{} {
severityCount := map[string]int{
"critical": 0,
"high": 0,
"medium": 0,
"low": 0,
}
for _, vuln := range vulnerabilities {
severityCount[vuln.Severity]++
}
return map[string]interface{}{
"total": len(vulnerabilities),
"severityCount": severityCount,
"vulnerabilities": vulnerabilities,
"generatedAt": time.Now(),
}
}

35
run.sh Normal file
View File

@@ -0,0 +1,35 @@
#!/bin/bash
# CyberStrikeAI 启动脚本
echo "🚀 启动 CyberStrikeAI..."
# 检查配置文件
if [ ! -f "config.yaml" ]; then
echo "❌ 配置文件 config.yaml 不存在"
exit 1
fi
# 检查Go环境
if ! command -v go &> /dev/null; then
echo "❌ Go 未安装,请先安装 Go 1.21 或更高版本"
exit 1
fi
# 下载依赖
echo "📦 下载依赖..."
go mod download
# 构建项目
echo "🔨 构建项目..."
go build -o cyberstrike-ai cmd/server/main.go
if [ $? -ne 0 ]; then
echo "❌ 构建失败"
exit 1
fi
# 运行服务器
echo "✅ 启动服务器..."
./cyberstrike-ai

691
web/static/css/style.css Normal file
View File

@@ -0,0 +1,691 @@
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
:root {
--primary-color: #1a1a1a;
--secondary-color: #2d2d2d;
--accent-color: #0066ff;
--accent-hover: #0052cc;
--bg-primary: #ffffff;
--bg-secondary: #f8f9fa;
--bg-tertiary: #f1f3f5;
--text-primary: #1a1a1a;
--text-secondary: #6c757d;
--text-muted: #adb5bd;
--border-color: #e9ecef;
--success-color: #28a745;
--warning-color: #ffc107;
--error-color: #dc3545;
--shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.05);
--shadow-md: 0 4px 6px rgba(0, 0, 0, 0.1);
--shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.15);
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif;
background: var(--bg-secondary);
margin: 0;
padding: 0;
color: var(--text-primary);
line-height: 1.6;
height: 100vh;
overflow: hidden;
}
.container {
max-width: 100%;
margin: 0;
background: var(--bg-primary);
height: 100vh;
display: flex;
flex-direction: column;
box-shadow: var(--shadow-lg);
overflow: hidden;
}
.main-layout {
display: flex;
flex: 1;
overflow: hidden;
min-height: 0;
}
header {
background: var(--primary-color);
color: white;
padding: 24px 32px;
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
flex-shrink: 0;
}
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
}
.logo {
display: flex;
align-items: center;
gap: 12px;
}
.logo svg {
color: var(--accent-color);
}
.logo h1 {
font-size: 1.75rem;
font-weight: 600;
letter-spacing: -0.5px;
margin: 0;
}
.header-subtitle {
font-size: 0.875rem;
color: rgba(255, 255, 255, 0.7);
margin: 0;
font-weight: 400;
}
/* 侧边栏样式 */
.sidebar {
width: 280px;
background: var(--bg-secondary);
border-right: 1px solid var(--border-color);
display: flex;
flex-direction: column;
flex-shrink: 0;
height: 100%;
overflow: hidden;
}
.sidebar-header {
padding: 16px;
border-bottom: 1px solid var(--border-color);
flex-shrink: 0;
}
.new-chat-btn {
width: 100%;
padding: 10px 16px;
background: var(--accent-color);
color: white;
border: none;
border-radius: 8px;
font-size: 0.9375rem;
font-weight: 500;
cursor: pointer;
transition: all 0.2s;
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
}
.new-chat-btn:hover {
background: var(--accent-hover);
transform: translateY(-1px);
box-shadow: var(--shadow-sm);
}
.new-chat-btn span {
font-size: 1.2em;
line-height: 1;
}
.sidebar-content {
flex: 1;
overflow-y: auto;
overflow-x: hidden;
padding: 16px;
min-height: 0;
}
.sidebar-title {
font-size: 0.8125rem;
font-weight: 600;
color: var(--text-secondary);
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 12px;
padding: 0 8px;
}
.conversations-list {
display: flex;
flex-direction: column;
gap: 4px;
}
.conversation-item {
padding: 12px;
border-radius: 8px;
cursor: pointer;
transition: all 0.2s;
border: 1px solid transparent;
}
.conversation-item:hover {
background: var(--bg-tertiary);
}
.conversation-item.active {
background: var(--bg-primary);
border-color: var(--accent-color);
}
.conversation-title {
font-size: 0.875rem;
font-weight: 500;
color: var(--text-primary);
margin-bottom: 4px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.conversation-time {
font-size: 0.75rem;
color: var(--text-muted);
}
/* 对话界面样式 */
.chat-container {
display: flex;
flex-direction: column;
flex: 1;
min-width: 0;
background: var(--bg-primary);
overflow: hidden;
height: 100%;
}
.chat-messages {
flex: 1;
overflow-y: auto;
overflow-x: hidden;
padding: 24px;
background: var(--bg-secondary);
display: flex;
flex-direction: column;
min-height: 0;
}
.message {
margin-bottom: 24px;
display: flex;
align-items: flex-start;
gap: 12px;
animation: fadeIn 0.3s ease-in;
width: 100%;
}
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.message.user {
flex-direction: row-reverse;
justify-content: flex-end;
}
.message.system {
justify-content: center;
margin-bottom: 16px;
}
.message-avatar {
width: 32px;
height: 32px;
border-radius: 6px;
display: flex;
align-items: center;
justify-content: center;
font-size: 0.75rem;
font-weight: 600;
flex-shrink: 0;
margin-top: 2px;
}
.message.user .message-avatar {
background: var(--accent-color);
color: white;
}
.message.assistant .message-avatar {
background: var(--bg-tertiary);
color: var(--text-secondary);
border: 1px solid var(--border-color);
}
.message.system .message-avatar {
display: none;
}
.message-content {
flex: 0 1 auto;
max-width: 70%;
min-width: 120px;
display: flex;
flex-direction: column;
}
.message.user .message-content {
align-items: flex-end;
margin-left: auto;
}
.message.assistant .message-content {
align-items: flex-start;
margin-right: auto;
}
.message.system .message-content {
max-width: 90%;
align-items: center;
margin: 0 auto;
}
.message-bubble {
padding: 12px 16px;
border-radius: 8px;
word-wrap: break-word;
word-break: break-word;
line-height: 1.6;
box-shadow: var(--shadow-sm);
white-space: pre-wrap;
}
.message.user .message-bubble {
background: var(--accent-color);
color: white;
border-bottom-right-radius: 2px;
}
.message.assistant .message-bubble {
background: var(--bg-primary);
color: var(--text-primary);
border: 1px solid var(--border-color);
border-bottom-left-radius: 2px;
}
.message.assistant .message-bubble pre {
margin: 0;
white-space: pre-wrap;
word-wrap: break-word;
font-family: inherit;
}
.message.system .message-bubble {
background: var(--bg-tertiary);
color: var(--text-secondary);
border: 1px solid var(--border-color);
text-align: center;
font-size: 0.875rem;
padding: 10px 16px;
width: 100%;
}
.message-time {
font-size: 0.6875rem;
color: var(--text-muted);
margin-top: 4px;
padding: 0 2px;
font-weight: 400;
}
/* MCP调用区域 */
.mcp-call-section {
margin-top: 12px;
padding-top: 12px;
border-top: 1px solid var(--border-color);
width: 100%;
}
.mcp-call-label {
font-size: 0.75rem;
color: var(--text-secondary);
margin-bottom: 8px;
display: flex;
align-items: center;
gap: 6px;
font-weight: 500;
}
.mcp-call-label::before {
content: '';
width: 4px;
height: 4px;
background: var(--accent-color);
border-radius: 50%;
display: inline-block;
flex-shrink: 0;
}
.mcp-call-buttons {
display: flex;
flex-wrap: wrap;
gap: 6px;
}
.chat-input-container {
display: flex;
gap: 12px;
padding: 20px 24px;
background: var(--bg-primary);
border-top: 1px solid var(--border-color);
flex-shrink: 0;
}
.chat-input-container input {
flex: 1;
padding: 12px 16px;
border: 1px solid var(--border-color);
border-radius: 8px;
font-size: 0.9375rem;
outline: none;
transition: all 0.2s;
background: var(--bg-primary);
color: var(--text-primary);
}
.chat-input-container input:focus {
border-color: var(--accent-color);
box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.1);
}
.chat-input-container input::placeholder {
color: var(--text-muted);
}
.chat-input-container button {
padding: 12px 24px;
background: var(--accent-color);
color: white;
border: none;
border-radius: 8px;
cursor: pointer;
font-size: 0.9375rem;
font-weight: 500;
transition: all 0.2s;
white-space: nowrap;
}
.chat-input-container button:hover {
background: var(--accent-hover);
transform: translateY(-1px);
box-shadow: var(--shadow-md);
}
.chat-input-container button:active {
transform: translateY(0);
}
/* MCP调用详情按钮 */
.mcp-detail-btn {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
background: var(--bg-primary);
color: var(--accent-color);
border: 1px solid var(--border-color);
border-radius: 6px;
font-size: 0.8125rem;
font-weight: 500;
cursor: pointer;
transition: all 0.2s;
}
.mcp-detail-btn:hover {
background: var(--accent-color);
color: white;
border-color: var(--accent-color);
transform: translateY(-1px);
box-shadow: var(--shadow-sm);
}
.mcp-detail-btn:active {
transform: translateY(0);
}
/* 模态框样式 */
.modal {
display: none;
position: fixed;
z-index: 1000;
left: 0;
top: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.6);
backdrop-filter: blur(4px);
overflow: auto;
animation: fadeIn 0.2s ease-in;
}
.modal-content {
background-color: var(--bg-primary);
margin: 5% auto;
padding: 0;
border-radius: 12px;
width: 90%;
max-width: 900px;
max-height: 85vh;
display: flex;
flex-direction: column;
box-shadow: var(--shadow-lg);
border: 1px solid var(--border-color);
animation: slideDown 0.3s ease-out;
}
@keyframes slideDown {
from {
opacity: 0;
transform: translateY(-20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.modal-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 20px 24px;
border-bottom: 1px solid var(--border-color);
background: var(--bg-primary);
}
.modal-header h2 {
margin: 0;
font-size: 1.25rem;
font-weight: 600;
color: var(--text-primary);
}
.modal-close {
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 6px;
cursor: pointer;
color: var(--text-secondary);
font-size: 1.5rem;
line-height: 1;
transition: all 0.2s;
border: none;
background: transparent;
}
.modal-close:hover {
background: var(--bg-tertiary);
color: var(--text-primary);
}
.modal-body {
padding: 24px;
overflow-y: auto;
flex: 1;
}
.detail-section {
margin-bottom: 24px;
}
.detail-section:last-child {
margin-bottom: 0;
}
.detail-section h3 {
color: var(--text-primary);
margin-bottom: 12px;
padding-bottom: 8px;
border-bottom: 1px solid var(--border-color);
font-size: 0.9375rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.detail-item {
margin-bottom: 10px;
padding: 8px 0;
display: flex;
align-items: baseline;
gap: 8px;
}
.detail-item strong {
color: var(--text-secondary);
font-weight: 500;
font-size: 0.875rem;
min-width: 80px;
}
.detail-item span {
color: var(--text-primary);
font-size: 0.875rem;
}
.code-block {
background: var(--bg-tertiary);
border: 1px solid var(--border-color);
border-radius: 8px;
padding: 16px;
font-family: 'SF Mono', 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', monospace;
font-size: 0.8125rem;
line-height: 1.6;
overflow-x: auto;
white-space: pre-wrap;
word-wrap: break-word;
max-height: 400px;
overflow-y: auto;
color: var(--text-primary);
}
.code-block.error {
background: #fff5f5;
border-color: var(--error-color);
color: var(--error-color);
}
/* 滚动条样式 */
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: var(--bg-secondary);
border-radius: 4px;
}
::-webkit-scrollbar-thumb {
background: var(--text-muted);
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: var(--text-secondary);
}
/* 响应式设计 */
@media (max-width: 768px) {
body {
height: 100vh;
overflow: hidden;
}
.container {
border-radius: 0;
height: 100vh;
overflow: hidden;
}
header {
padding: 16px 20px;
flex-shrink: 0;
}
.logo h1 {
font-size: 1.5rem;
}
.header-subtitle {
display: none;
}
.main-layout {
overflow: hidden;
min-height: 0;
}
.sidebar {
height: 100%;
overflow: hidden;
}
.sidebar-content {
min-height: 0;
}
.chat-container {
height: 100%;
overflow: hidden;
}
.chat-messages {
padding: 16px;
min-height: 0;
}
.message-content {
max-width: 85%;
}
.chat-input-container {
padding: 16px;
flex-shrink: 0;
}
.modal-content {
width: 95%;
margin: 10% auto;
}
}

400
web/static/js/app.js Normal file
View File

@@ -0,0 +1,400 @@
// 当前对话ID
let currentConversationId = null;
// 发送消息
async function sendMessage() {
const input = document.getElementById('chat-input');
const message = input.value.trim();
if (!message) {
return;
}
// 显示用户消息
addMessage('user', message);
input.value = '';
// 显示加载状态
const loadingId = addMessage('system', '正在处理中...');
try {
const response = await fetch('/api/agent-loop', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
message: message,
conversationId: currentConversationId
}),
});
const data = await response.json();
// 移除加载消息
removeMessage(loadingId);
if (response.ok) {
// 更新当前对话ID
if (data.conversationId) {
currentConversationId = data.conversationId;
updateActiveConversation();
}
// 如果有MCP执行ID显示所有调用
const mcpIds = data.mcpExecutionIds || [];
addMessage('assistant', data.response, mcpIds);
// 刷新对话列表
loadConversations();
} else {
addMessage('system', '错误: ' + (data.error || '未知错误'));
}
} catch (error) {
removeMessage(loadingId);
addMessage('system', '错误: ' + error.message);
}
}
// 消息计数器确保ID唯一
let messageCounter = 0;
// 添加消息
function addMessage(role, content, mcpExecutionIds = null) {
const messagesDiv = document.getElementById('chat-messages');
const messageDiv = document.createElement('div');
messageCounter++;
const id = 'msg-' + Date.now() + '-' + messageCounter + '-' + Math.random().toString(36).substr(2, 9);
messageDiv.id = id;
messageDiv.className = 'message ' + role;
// 创建头像
const avatar = document.createElement('div');
avatar.className = 'message-avatar';
if (role === 'user') {
avatar.textContent = 'U';
} else if (role === 'assistant') {
avatar.textContent = 'A';
} else {
avatar.textContent = 'S';
}
messageDiv.appendChild(avatar);
// 创建消息内容容器
const contentWrapper = document.createElement('div');
contentWrapper.className = 'message-content';
// 创建消息气泡
const bubble = document.createElement('div');
bubble.className = 'message-bubble';
// 处理换行和格式化
const formattedContent = content.replace(/\n/g, '<br>');
bubble.innerHTML = formattedContent;
contentWrapper.appendChild(bubble);
// 添加时间戳
const timeDiv = document.createElement('div');
timeDiv.className = 'message-time';
timeDiv.textContent = new Date().toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' });
contentWrapper.appendChild(timeDiv);
// 如果有MCP执行ID添加查看详情区域
if (mcpExecutionIds && Array.isArray(mcpExecutionIds) && mcpExecutionIds.length > 0 && role === 'assistant') {
const mcpSection = document.createElement('div');
mcpSection.className = 'mcp-call-section';
const mcpLabel = document.createElement('div');
mcpLabel.className = 'mcp-call-label';
mcpLabel.textContent = `工具调用 (${mcpExecutionIds.length})`;
mcpSection.appendChild(mcpLabel);
const buttonsContainer = document.createElement('div');
buttonsContainer.className = 'mcp-call-buttons';
mcpExecutionIds.forEach((execId, index) => {
const detailBtn = document.createElement('button');
detailBtn.className = 'mcp-detail-btn';
detailBtn.innerHTML = `<span>调用 #${index + 1}</span>`;
detailBtn.onclick = () => showMCPDetail(execId);
buttonsContainer.appendChild(detailBtn);
});
mcpSection.appendChild(buttonsContainer);
contentWrapper.appendChild(mcpSection);
}
messageDiv.appendChild(contentWrapper);
messagesDiv.appendChild(messageDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight;
return id;
}
// 移除消息
function removeMessage(id) {
const messageDiv = document.getElementById(id);
if (messageDiv) {
messageDiv.remove();
}
}
// 回车发送消息
document.getElementById('chat-input').addEventListener('keypress', function(e) {
if (e.key === 'Enter') {
sendMessage();
}
});
// 显示MCP调用详情
async function showMCPDetail(executionId) {
try {
const response = await fetch(`/api/monitor/execution/${executionId}`);
const exec = await response.json();
if (response.ok) {
// 填充模态框内容
document.getElementById('detail-tool-name').textContent = exec.toolName || 'Unknown';
document.getElementById('detail-execution-id').textContent = exec.id || 'N/A';
document.getElementById('detail-status').textContent = getStatusText(exec.status);
document.getElementById('detail-time').textContent = new Date(exec.startTime).toLocaleString('zh-CN');
// 请求参数
const requestData = {
tool: exec.toolName,
arguments: exec.arguments
};
document.getElementById('detail-request').textContent = JSON.stringify(requestData, null, 2);
// 响应结果
if (exec.result) {
const responseData = {
content: exec.result.content,
isError: exec.result.isError
};
document.getElementById('detail-response').textContent = JSON.stringify(responseData, null, 2);
document.getElementById('detail-response').className = exec.result.isError ? 'code-block error' : 'code-block';
} else {
document.getElementById('detail-response').textContent = '暂无响应数据';
}
// 错误信息
if (exec.error) {
document.getElementById('detail-error-section').style.display = 'block';
document.getElementById('detail-error').textContent = exec.error;
} else {
document.getElementById('detail-error-section').style.display = 'none';
}
// 显示模态框
document.getElementById('mcp-detail-modal').style.display = 'block';
} else {
alert('获取详情失败: ' + (exec.error || '未知错误'));
}
} catch (error) {
alert('获取详情失败: ' + error.message);
}
}
// 关闭MCP详情模态框
function closeMCPDetail() {
document.getElementById('mcp-detail-modal').style.display = 'none';
}
// 点击模态框外部关闭
window.onclick = function(event) {
const modal = document.getElementById('mcp-detail-modal');
if (event.target == modal) {
closeMCPDetail();
}
}
// 工具函数
function getStatusText(status) {
const statusMap = {
'pending': '等待中',
'running': '执行中',
'completed': '已完成',
'failed': '失败'
};
return statusMap[status] || status;
}
function formatDuration(ms) {
const seconds = Math.floor(ms / 1000);
const minutes = Math.floor(seconds / 60);
const hours = Math.floor(minutes / 60);
if (hours > 0) {
return `${hours}小时${minutes % 60}分钟`;
} else if (minutes > 0) {
return `${minutes}分钟${seconds % 60}`;
} else {
return `${seconds}`;
}
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
// 开始新对话
function startNewConversation() {
currentConversationId = null;
document.getElementById('chat-messages').innerHTML = '';
addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。');
updateActiveConversation();
}
// 加载对话列表
async function loadConversations() {
try {
const response = await fetch('/api/conversations?limit=50');
const conversations = await response.json();
const listContainer = document.getElementById('conversations-list');
listContainer.innerHTML = '';
if (conversations.length === 0) {
listContainer.innerHTML = '<div style="padding: 20px; text-align: center; color: var(--text-muted); font-size: 0.875rem;">暂无历史对话</div>';
return;
}
conversations.forEach(conv => {
const item = document.createElement('div');
item.className = 'conversation-item';
item.dataset.conversationId = conv.id;
if (conv.id === currentConversationId) {
item.classList.add('active');
}
const title = document.createElement('div');
title.className = 'conversation-title';
title.textContent = conv.title || '未命名对话';
item.appendChild(title);
const time = document.createElement('div');
time.className = 'conversation-time';
// 解析时间,支持多种格式
let dateObj;
if (conv.updatedAt) {
dateObj = new Date(conv.updatedAt);
// 检查日期是否有效
if (isNaN(dateObj.getTime())) {
// 如果解析失败,尝试其他格式
console.warn('时间解析失败:', conv.updatedAt);
dateObj = new Date();
}
} else {
dateObj = new Date();
}
// 格式化时间显示
const now = new Date();
const today = new Date(now.getFullYear(), now.getMonth(), now.getDate());
const yesterday = new Date(today);
yesterday.setDate(yesterday.getDate() - 1);
const messageDate = new Date(dateObj.getFullYear(), dateObj.getMonth(), dateObj.getDate());
let timeText;
if (messageDate.getTime() === today.getTime()) {
// 今天:只显示时间
timeText = dateObj.toLocaleTimeString('zh-CN', {
hour: '2-digit',
minute: '2-digit'
});
} else if (messageDate.getTime() === yesterday.getTime()) {
// 昨天
timeText = '昨天 ' + dateObj.toLocaleTimeString('zh-CN', {
hour: '2-digit',
minute: '2-digit'
});
} else if (now.getFullYear() === dateObj.getFullYear()) {
// 今年:显示月日和时间
timeText = dateObj.toLocaleString('zh-CN', {
month: 'short',
day: 'numeric',
hour: '2-digit',
minute: '2-digit'
});
} else {
// 去年或更早:显示完整日期和时间
timeText = dateObj.toLocaleString('zh-CN', {
year: 'numeric',
month: 'short',
day: 'numeric',
hour: '2-digit',
minute: '2-digit'
});
}
time.textContent = timeText;
item.appendChild(time);
item.onclick = () => loadConversation(conv.id);
listContainer.appendChild(item);
});
} catch (error) {
console.error('加载对话列表失败:', error);
}
}
// 加载对话
async function loadConversation(conversationId) {
try {
const response = await fetch(`/api/conversations/${conversationId}`);
const conversation = await response.json();
if (!response.ok) {
alert('加载对话失败: ' + (conversation.error || '未知错误'));
return;
}
// 更新当前对话ID
currentConversationId = conversationId;
updateActiveConversation();
// 清空消息区域
const messagesDiv = document.getElementById('chat-messages');
messagesDiv.innerHTML = '';
// 加载消息
if (conversation.messages && conversation.messages.length > 0) {
conversation.messages.forEach(msg => {
addMessage(msg.role, msg.content, msg.mcpExecutionIds || []);
});
} else {
addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。');
}
// 滚动到底部
messagesDiv.scrollTop = messagesDiv.scrollHeight;
// 刷新对话列表
loadConversations();
} catch (error) {
console.error('加载对话失败:', error);
alert('加载对话失败: ' + error.message);
}
}
// 更新活动对话样式
function updateActiveConversation() {
document.querySelectorAll('.conversation-item').forEach(item => {
item.classList.remove('active');
if (currentConversationId && item.dataset.conversationId === currentConversationId) {
item.classList.add('active');
}
});
}
// 页面加载时初始化
document.addEventListener('DOMContentLoaded', function() {
// 加载对话列表
loadConversations();
// 添加欢迎消息
addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。');
});

92
web/templates/index.html Normal file
View File

@@ -0,0 +1,92 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>CyberStrikeAI - 自主渗透测试平台</title>
<link rel="stylesheet" href="/static/css/style.css">
</head>
<body>
<div class="container">
<header>
<div class="header-content">
<div class="logo">
<svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 2L2 7L12 12L22 7L12 2Z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2 17L12 22L22 17" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2 12L12 17L22 12" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<h1>CyberStrike</h1>
</div>
<p class="header-subtitle">安全测试平台</p>
</div>
</header>
<div class="main-layout">
<!-- 历史对话侧边栏 -->
<aside class="sidebar">
<div class="sidebar-header">
<button class="new-chat-btn" onclick="startNewConversation()">
<span>+</span> 新对话
</button>
</div>
<div class="sidebar-content">
<div class="sidebar-title">历史对话</div>
<div id="conversations-list" class="conversations-list"></div>
</div>
</aside>
<!-- 对话界面 -->
<div class="chat-container">
<div id="chat-messages" class="chat-messages"></div>
<div class="chat-input-container">
<input type="text" id="chat-input" placeholder="输入测试目标或命令..." />
<button onclick="sendMessage()">发送</button>
</div>
</div>
</div>
</div>
<!-- MCP调用详情模态框 -->
<div id="mcp-detail-modal" class="modal">
<div class="modal-content">
<div class="modal-header">
<h2>工具调用详情</h2>
<span class="modal-close" onclick="closeMCPDetail()">&times;</span>
</div>
<div class="modal-body">
<div class="detail-section">
<h3>执行信息</h3>
<div class="detail-item">
<strong>工具:</strong> <span id="detail-tool-name"></span>
</div>
<div class="detail-item">
<strong>状态:</strong> <span id="detail-status"></span>
</div>
<div class="detail-item">
<strong>时间:</strong> <span id="detail-time"></span>
</div>
<div class="detail-item">
<strong>ID:</strong> <span id="detail-execution-id" style="font-family: monospace; font-size: 0.8125rem; color: var(--text-secondary);"></span>
</div>
</div>
<div class="detail-section">
<h3>请求参数</h3>
<pre id="detail-request" class="code-block"></pre>
</div>
<div class="detail-section">
<h3>响应结果</h3>
<pre id="detail-response" class="code-block"></pre>
</div>
<div class="detail-section" id="detail-error-section" style="display: none;">
<h3>错误信息</h3>
<pre id="detail-error" class="code-block error"></pre>
</div>
</div>
</div>
</div>
<script src="/static/js/app.js"></script>
</body>
</html>