Files
CyberStrikeAI/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/CyberStrikeAIClient.java
T
2026-06-02 13:32:05 +08:00

370 lines
14 KiB
Java

package burp;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.SocketTimeoutException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Minimal HTTP client for the CyberStrikeAI backend used by the Burp extension.
 *
 * <p>Supports password login plus token validation, streaming an agent run over Server-Sent
 * Events, and cancelling a running agent loop by conversation id. One in-flight request
 * (connection + owning thread) is tracked so that {@link #cancelActiveRequest()} can abort it;
 * starting another tracked request simply replaces the tracked one.
 */
final class CyberStrikeAIClient {
    private static final int AUTH_CONNECT_TIMEOUT_MS = 4_000;
    private static final int AUTH_READ_TIMEOUT_MS = 5_000;
    /**
     * Overall cap for login + validate together, so that two back-to-back connect/read timeouts
     * cannot stack up to almost 20 seconds.
     */
    private static final int AUTH_OVERALL_TIMEOUT_MS = 10_000;

    /** Connection currently tracked for cancellation, or null when idle. */
    private final AtomicReference<HttpURLConnection> activeConnection = new AtomicReference<>();
    /** Thread that owns {@link #activeConnection}; interrupted on cancellation. */
    private final AtomicReference<Thread> activeThread = new AtomicReference<>();

    /** Connection settings for one backend. */
    static final class Config {
        final String baseUrl; // e.g. http://127.0.0.1:8080
        final String password;
        final AgentMode agentMode;

        Config(String baseUrl, String password, AgentMode agentMode) {
            this.baseUrl = baseUrl;
            this.password = password;
            this.agentMode = agentMode;
        }
    }

    /** Agent flavours: UI label, stream endpoint path and optional orchestration name. */
    enum AgentMode {
        EINO_SINGLE("Eino Single (ADK)", "/api/eino-agent/stream", null),
        DEEP("Deep (DeepAgent)", "/api/multi-agent/stream", "deep"),
        PLAN_EXECUTE("Plan-Execute", "/api/multi-agent/stream", "plan_execute"),
        SUPERVISOR("Supervisor", "/api/multi-agent/stream", "supervisor");

        final String displayName;
        final String streamPath;
        /** Value sent as the JSON "orchestration" field, or null to omit the field. */
        final String orchestration;

        AgentMode(String displayName, String streamPath, String orchestration) {
            this.displayName = displayName;
            this.streamPath = streamPath;
            this.orchestration = orchestration;
        }
    }

    /** Callbacks for {@link #streamTest}; invoked on the background stream thread. */
    interface StreamListener {
        /** One SSE {@code data:} event; {@code type}/{@code message} are extracted from its JSON. */
        void onEvent(String type, String message, String rawJson);

        /** Terminal failure or cancellation; {@code e} may be null. */
        void onError(String message, Exception e);

        /** Terminal success: the stream ended or a {@code "done"} event arrived. */
        void onDone();
    }

    /** Returns whether a request is currently tracked for cancellation. */
    boolean hasActiveRequest() {
        return activeConnection.get() != null;
    }

    /**
     * Aborts the tracked request, if any, by disconnecting its connection and interrupting its
     * owning thread. Does nothing when no request is tracked.
     */
    void cancelActiveRequest() {
        HttpURLConnection conn = activeConnection.getAndSet(null);
        if (conn != null) {
            try {
                conn.disconnect();
            } catch (Exception ignored) {
                // Best effort: the connection may already be closed.
            }
        }
        Thread t = activeThread.getAndSet(null);
        if (t != null) {
            t.interrupt();
        }
    }

    /**
     * Logs in with the configured password and validates the returned token.
     *
     * <p>Both requests share an overall budget of {@value #AUTH_OVERALL_TIMEOUT_MS} ms. When it
     * expires, the in-flight request is aborted and a timeout {@link IOException} is thrown.
     * Interruption of the calling thread between the two requests is reported the same way.
     *
     * @return the bearer token issued by the backend
     * @throws IOException on timeout/interruption, HTTP failure, or an HTML/malformed response
     */
    String loginAndValidate(Config cfg) throws IOException {
        AuthDeadline deadline = new AuthDeadline(Thread.currentThread());
        java.util.Timer timer = new java.util.Timer("CyberStrikeAI-AuthDeadline", true);
        timer.schedule(deadline, AUTH_OVERALL_TIMEOUT_MS);
        try {
            String token = login(cfg.baseUrl, cfg.password);
            // Thread.interrupted() also clears the flag, so the caller's thread is not left interrupted.
            if (!Thread.interrupted()) {
                validate(cfg.baseUrl, token);
                if (!Thread.interrupted()) {
                    return token;
                }
            }
        } catch (IOException e) {
            // A socket timeout, or any failure caused by the deadline disconnecting the socket,
            // is reported as one uniform timeout (keeping the original exception as the cause).
            if (e instanceof SocketTimeoutException || deadline.hasFired()) {
                throw timeoutIOException(e);
            }
            throw e;
        } finally {
            // Disarm under the task's lock: afterwards a late deadline can neither interrupt this
            // thread nor cancel a request that the caller starts next.
            if (deadline.disarm()) {
                Thread.interrupted(); // drop an interrupt delivered after our last check
            }
            timer.cancel();
        }
        // Reached only when the thread was interrupted between/after the two requests.
        throw timeoutIOException(null);
    }

    /**
     * One-shot deadline for {@link #loginAndValidate}: when it fires, it interrupts the auth thread
     * and aborts the tracked request. {@link #run()} and {@link #disarm()} are mutually exclusive,
     * so once disarmed the task has no further effect.
     */
    private final class AuthDeadline extends java.util.TimerTask {
        private final Thread worker;
        private boolean fired;
        private boolean disarmed;

        AuthDeadline(Thread worker) {
            this.worker = worker;
        }

        @Override
        public synchronized void run() {
            if (disarmed) {
                return;
            }
            fired = true;
            worker.interrupt();
            cancelActiveRequest();
        }

        /** Returns whether the deadline has fired so far. */
        synchronized boolean hasFired() {
            return fired;
        }

        /** Prevents any future firing and returns whether the deadline had already fired. */
        synchronized boolean disarm() {
            disarmed = true;
            return fired;
        }
    }

    /** Builds the user-facing auth timeout error, optionally chaining the underlying cause. */
    private static IOException timeoutIOException(Throwable cause) {
        return new IOException("Connection timed out (~" + (AUTH_OVERALL_TIMEOUT_MS / 1000)
                + "s). Check host/port and HTTPS checkbox.", cause);
    }

    /** Records {@code conn} and the current thread as the request to abort on cancellation. */
    private void trackConnection(HttpURLConnection conn) {
        activeThread.set(Thread.currentThread());
        activeConnection.set(conn);
    }

    /** Stops tracking {@code conn}, unless another request has replaced it in the meantime. */
    private void releaseConnection(HttpURLConnection conn) {
        if (activeConnection.compareAndSet(conn, null)) {
            activeThread.set(null);
        }
    }

    /**
     * Heuristically decides whether a failure was caused by cancellation rather than a real error.
     * A pending interrupt always counts as cancellation; a socket timeout never does.
     */
    private static boolean isCancelled(Throwable e) {
        if (Thread.currentThread().isInterrupted()) {
            return true;
        }
        if (e == null) {
            return false;
        }
        // SocketTimeoutException is a subclass of InterruptedIOException, so it must be tested
        // first; otherwise genuine timeouts would be reported as "Cancelled.".
        if (e instanceof SocketTimeoutException) {
            return false;
        }
        if (e instanceof InterruptedIOException) {
            return true;
        }
        Throwable cause = e.getCause();
        if (cause != null && cause != e) {
            return isCancelled(cause);
        }
        String msg = e.getMessage();
        if (msg == null) {
            return false;
        }
        String lower = msg.toLowerCase(Locale.ROOT);
        return lower.contains("cancel") || lower.contains("abort") || lower.contains("closed");
    }

    /** Returns whether {@code code} is an HTTP 2xx status. */
    private static boolean isSuccess(int code) {
        return code >= 200 && code < 300;
    }

    /**
     * Formats "{action} failed (code): error" using the JSON "error" field of {@code resp}, or
     * "{action} failed (code)." when that field is empty.
     */
    private static String describeHttpError(String action, int code, String resp) {
        String serverError = SimpleJson.extractStringField(resp, "error");
        return serverError.isEmpty()
                ? action + " failed (" + code + ")."
                : action + " failed (" + code + "): " + serverError;
    }

    /** POSTs the password to /api/auth/login and returns the "token" field of the JSON reply. */
    private String login(String baseUrl, String password) throws IOException {
        URL url = new URL(baseUrl + "/api/auth/login");
        HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
        trackConnection(conn);
        try {
            conn.setRequestMethod("POST");
            conn.setDoOutput(true);
            conn.setRequestProperty("Content-Type", "application/json");
            conn.setRequestProperty("Accept", "application/json");
            String body = "{\"password\":\"" + escapeJson(password) + "\"}";
            try (OutputStream os = conn.getOutputStream()) {
                os.write(body.getBytes(StandardCharsets.UTF_8));
            }
            int code = conn.getResponseCode();
            String contentType = conn.getHeaderField("Content-Type");
            String resp = readAll(isSuccess(code) ? conn.getInputStream() : conn.getErrorStream());
            // Friendly diagnosis: HTML usually means wrong host/port (e.g., hit Burp UI/proxy page).
            boolean htmlContentType = contentType != null
                    && contentType.toLowerCase(Locale.ROOT).contains("text/html");
            if (looksLikeHtml(resp) || htmlContentType) {
                throw new IOException("Login failed: server returned HTML, not API JSON. Check IP/Port and ensure you point to CyberStrikeAI backend.");
            }
            if (!isSuccess(code)) {
                throw new IOException(describeHttpError("Login", code, resp));
            }
            // A 2xx body carrying an "error" field is still treated as a failed login.
            String serverError = SimpleJson.extractStringField(resp, "error");
            if (!serverError.isEmpty()) {
                throw new IOException("Login failed: " + serverError);
            }
            String token = SimpleJson.extractStringField(resp, "token");
            if (token.isEmpty()) {
                throw new IOException("Login response missing token. Check backend address and credentials.");
            }
            return token;
        } finally {
            releaseConnection(conn);
        }
    }

    /** GETs /api/auth/validate with the bearer token; throws unless the reply is 2xx. */
    private void validate(String baseUrl, String token) throws IOException {
        URL url = new URL(baseUrl + "/api/auth/validate");
        HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
        trackConnection(conn);
        try {
            conn.setRequestMethod("GET");
            conn.setRequestProperty("Authorization", "Bearer " + token);
            int code = conn.getResponseCode();
            String resp = readAll(isSuccess(code) ? conn.getInputStream() : conn.getErrorStream());
            if (!isSuccess(code)) {
                throw new IOException("Validate failed (" + code + "): " + resp);
            }
        } finally {
            releaseConnection(conn);
        }
    }

    /**
     * Starts an agent run for {@code message} and streams its SSE events to {@code listener} on a
     * new background thread; returns immediately.
     *
     * <p>The run ends with {@code onDone()} when the stream ends or a {@code "done"} event arrives.
     * Otherwise it ends with {@code onError(...)}: the message is {@code "Cancelled."} for an abort
     * via {@link #cancelActiveRequest()}, or describes the HTTP/IO failure. An exception thrown by
     * a listener callback is itself reported through {@code onError}.
     */
    void streamTest(Config cfg, String token, String message, StreamListener listener) {
        String urlStr = cfg.baseUrl + cfg.agentMode.streamPath;
        Map<String, Object> payload = new HashMap<>();
        payload.put("message", message);
        payload.put("conversationId", "");
        payload.put("role", "");
        if (cfg.agentMode.orchestration != null) {
            payload.put("orchestration", cfg.agentMode.orchestration);
        }
        String body = toJson(payload);
        Thread worker = new Thread(() -> runStream(urlStr, token, body, listener), "CyberStrikeAI-Stream");
        worker.start();
    }

    /** Stream-thread body: performs the POST, pumps SSE events and emits the terminal callback. */
    private void runStream(String urlStr, String token, String body, StreamListener listener) {
        HttpURLConnection conn = null;
        try {
            URL url = new URL(urlStr);
            // Read timeout 0 (infinite) — presumably because an agent run can stay quiet for long
            // periods between events; cancellation goes through cancelActiveRequest() instead.
            conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, 0);
            trackConnection(conn);
            conn.setRequestMethod("POST");
            conn.setDoOutput(true);
            conn.setRequestProperty("Content-Type", "application/json");
            conn.setRequestProperty("Accept", "text/event-stream");
            conn.setRequestProperty("Authorization", "Bearer " + token);
            try (OutputStream os = conn.getOutputStream()) {
                os.write(body.getBytes(StandardCharsets.UTF_8));
            }
            int code = conn.getResponseCode();
            if (!isSuccess(code)) {
                // The error body is not parsed as SSE (it would yield no events and a misleading
                // onDone()); report the failure with the server's reason instead.
                String detail = describeHttpError("Stream request", code, readAll(conn.getErrorStream()));
                listener.onError(detail, new IOException(detail));
                return;
            }
            pumpEvents(conn.getInputStream(), listener);
            if (Thread.currentThread().isInterrupted()) {
                listener.onError("Cancelled.", null);
            } else {
                listener.onDone();
            }
        } catch (Exception e) {
            listener.onError(isCancelled(e) ? "Cancelled." : e.getMessage(), e);
        } finally {
            if (conn != null) {
                releaseConnection(conn);
                conn.disconnect();
            }
        }
    }

    /**
     * Reads SSE lines and forwards every non-empty {@code data:} payload to the listener.
     * Stops at end of stream, after a {@code "done"} event, or when the thread is interrupted.
     * NOTE(review): each {@code data:} line is treated as a complete event; multi-line SSE data
     * fields are not joined.
     */
    private static void pumpEvents(InputStream is, StreamListener listener) throws IOException {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                if (Thread.currentThread().isInterrupted()) {
                    return;
                }
                // SSE format: "data: {json}"; all other lines are ignored.
                if (!line.startsWith("data:")) {
                    continue;
                }
                String json = line.substring("data:".length()).trim();
                if (json.isEmpty()) {
                    continue;
                }
                String type = SimpleJson.extractStringField(json, "type");
                String msg = SimpleJson.extractStringField(json, "message");
                listener.onEvent(type, msg, json);
                if ("done".equals(type)) {
                    return;
                }
            }
        }
    }

    /**
     * Asks the backend to stop the agent loop of {@code conversationId}.
     *
     * @throws IOException if the id is blank, the request fails, or the reply is not 2xx
     */
    void cancelByConversationId(String baseUrl, String token, String conversationId) throws IOException {
        if (conversationId == null || conversationId.trim().isEmpty()) {
            throw new IOException("Missing conversationId.");
        }
        URL url = new URL(baseUrl + "/api/agent-loop/cancel");
        // Intentionally not tracked: trackConnection() would replace the tracked (stream) request.
        HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
        conn.setRequestMethod("POST");
        conn.setDoOutput(true);
        conn.setRequestProperty("Content-Type", "application/json");
        conn.setRequestProperty("Accept", "application/json");
        conn.setRequestProperty("Authorization", "Bearer " + token);
        String body = "{\"conversationId\":\"" + escapeJson(conversationId.trim()) + "\"}";
        try (OutputStream os = conn.getOutputStream()) {
            os.write(body.getBytes(StandardCharsets.UTF_8));
        }
        int code = conn.getResponseCode();
        String resp = readAll(isSuccess(code) ? conn.getInputStream() : conn.getErrorStream());
        if (!isSuccess(code)) {
            throw new IOException(describeHttpError("Cancel", code, resp));
        }
    }

    /**
     * Serializes the stream payload: "message", "conversationId" and "role" are always written
     * (null becomes ""); "orchestration" only when present and non-null.
     */
    private static String toJson(Map<String, Object> payload) {
        String message = payload.get("message") != null ? String.valueOf(payload.get("message")) : "";
        String conversationId = payload.get("conversationId") != null ? String.valueOf(payload.get("conversationId")) : "";
        String role = payload.get("role") != null ? String.valueOf(payload.get("role")) : "";
        StringBuilder sb = new StringBuilder();
        sb.append("{");
        sb.append("\"message\":\"").append(escapeJson(message)).append("\",");
        sb.append("\"conversationId\":\"").append(escapeJson(conversationId)).append("\",");
        sb.append("\"role\":\"").append(escapeJson(role)).append("\"");
        Object orchestration = payload.get("orchestration");
        if (orchestration != null) {
            sb.append(",\"orchestration\":\"").append(escapeJson(String.valueOf(orchestration))).append("\"");
        }
        sb.append("}");
        return sb.toString();
    }

    /** Escapes {@code s} for use inside a JSON string literal; null yields "". */
    private static String escapeJson(String s) {
        if (s == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder(s.length() + 16);
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '\\': sb.append("\\\\"); break;
                case '"': sb.append("\\\""); break;
                case '\n': sb.append("\\n"); break;
                case '\r': sb.append("\\r"); break;
                case '\t': sb.append("\\t"); break;
                default:
                    if (c < 0x20) {
                        // Remaining control characters must be \\u-escaped in JSON.
                        sb.append(String.format(Locale.ROOT, "\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.toString();
    }

    /**
     * Reads {@code is} fully as UTF-8 and closes it; lines are joined with '\n' and the result is
     * trimmed. Returns "" for a null stream.
     */
    private static String readAll(InputStream is) throws IOException {
        if (is == null) {
            return "";
        }
        try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            StringBuilder sb = new StringBuilder();
            String line;
            while ((line = br.readLine()) != null) {
                sb.append(line).append('\n');
            }
            return sb.toString().trim();
        }
    }

    /** Heuristic: true when the body starts like an HTML document or contains head/body tags. */
    private static boolean looksLikeHtml(String s) {
        if (s == null) {
            return false;
        }
        String t = s.trim().toLowerCase(Locale.ROOT);
        return t.startsWith("<!doctype html") || t.startsWith("<html") || t.contains("<head>") || t.contains("<body");
    }
}