diff --git a/mcp-servers/reverse_shell/mcp_reverse_shell.py b/mcp-servers/reverse_shell/mcp_reverse_shell.py index bd2026c8..df427764 100644 --- a/mcp-servers/reverse_shell/mcp_reverse_shell.py +++ b/mcp-servers/reverse_shell/mcp_reverse_shell.py @@ -29,6 +29,11 @@ _LISTENER_PORT: int | None = None _CLIENT_SOCK: socket.socket | None = None _CLIENT_ADDR: tuple[str, int] | None = None _LOCK = threading.Lock() +_STOP_EVENT = threading.Event() +_READY_EVENT = threading.Event() +_LAST_LISTEN_ERROR: str | None = None +_LISTENER_THREAD_JOIN_TIMEOUT = 1.0 +_START_READY_TIMEOUT = 1.5 # 用于 send_command 的输出结束标记(避免无限等待) _END_MARKER = "__RS_DONE__" @@ -62,37 +67,55 @@ def _get_local_ips() -> list[str]: def _accept_loop(port: int) -> None: """在后台线程中:bind、listen、accept,只接受一个客户端。""" - global _LISTENER, _CLIENT_SOCK, _CLIENT_ADDR, _LISTENER_PORT + global _LISTENER, _CLIENT_SOCK, _CLIENT_ADDR, _LISTENER_PORT, _LAST_LISTEN_ERROR + sock: socket.socket | None = None try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(("0.0.0.0", port)) sock.listen(1) + # 避免 stop_listener 关闭后 accept() 长时间不返回:用超时轮询检查停止事件 + sock.settimeout(0.5) with _LOCK: _LISTENER = sock - # 阻塞 accept,只接受一个连接 - client, addr = sock.accept() + _LISTENER_PORT = port + _LAST_LISTEN_ERROR = None + _READY_EVENT.set() + # 循环 accept:只接受一个连接,或等待 stop 事件 + while not _STOP_EVENT.is_set(): + try: + client, addr = sock.accept() + except socket.timeout: + continue + except OSError: + break + with _LOCK: + _CLIENT_SOCK = client + _CLIENT_ADDR = (addr[0], addr[1]) + break + except OSError as e: with _LOCK: - _CLIENT_SOCK = client - _CLIENT_ADDR = (addr[0], addr[1]) - except OSError: - pass + _LAST_LISTEN_ERROR = str(e) + _READY_EVENT.set() finally: with _LOCK: - if _LISTENER: - try: - _LISTENER.close() - except OSError: - pass - _LISTENER = None + _LISTENER = None _LISTENER_PORT = None + if sock is not None: + try: + sock.close() + except OSError: + pass def _start_listener(port: int) -> str: - global _LISTENER_THREAD, _LISTENER_PORT, _CLIENT_SOCK, _CLIENT_ADDR + global _LISTENER_THREAD, _LISTENER_PORT, _CLIENT_SOCK, _CLIENT_ADDR, _LAST_LISTEN_ERROR + old_thread: threading.Thread | None = None with _LOCK: - if _LISTENER is not None or (_LISTENER_THREAD is not None and _LISTENER_THREAD.is_alive()): - return f"已在监听中(端口: {_LISTENER_PORT}),请先 stop_listener 再重新 start。" + if _LISTENER is not None: + # _LISTENER_PORT 可能短暂为 None(例如刚 stop/start),因此做个兜底显示 + show_port = _LISTENER_PORT if _LISTENER_PORT is not None else port + return f"已在监听中(端口: {show_port}),请先 stop_listener 再重新 start。" if _CLIENT_SOCK is not None: try: _CLIENT_SOCK.close() @@ -100,39 +123,72 @@ def _start_listener(port: int) -> str: pass _CLIENT_SOCK = None _CLIENT_ADDR = None + old_thread = _LISTENER_THREAD + + # 若旧线程还没完全退出,短暂等待一下以减少端口绑定失败概率 + if old_thread is not None and old_thread.is_alive(): + old_thread.join(timeout=0.5) + + _STOP_EVENT.clear() + _READY_EVENT.clear() + _LAST_LISTEN_ERROR = None th = threading.Thread(target=_accept_loop, args=(port,), daemon=True) th.start() _LISTENER_THREAD = th - time.sleep(0.2) + + # 等待后台线程完成 bind/listen(或失败) + _READY_EVENT.wait(timeout=_START_READY_TIMEOUT) with _LOCK: - if _LISTENER is not None: - _LISTENER_PORT = port - ips = _get_local_ips() - addrs = ", ".join(f"{ip}:{port}" for ip in ips) - return ( - f"已在 0.0.0.0:{port} 开始监听。" - f"目标机请反弹到: {addrs}(任选其一)。连接后使用 reverse_shell_send_command 执行命令。" - ) - return f"监听 0.0.0.0:{port} 已启动(若端口被占用会失败,请检查)。" + err = _LAST_LISTEN_ERROR + listening = _LISTENER is not None + + if listening: + ips = _get_local_ips() + addrs = ", ".join(f"{ip}:{port}" for ip in ips) + return ( + f"已在 0.0.0.0:{port} 开始监听。" + f"目标机请反弹到: {addrs}(任选其一)。连接后使用 reverse_shell_send_command 执行命令。" + ) + + if err: + return f"启动监听失败(0.0.0.0:{port}):{err}" + + # 仍未准备好:可能线程调度较慢或环境异常;给出可操作的提示 + return f"启动监听未确认成功(0.0.0.0:{port})。请调用 reverse_shell_status 确认,或稍后重试。" def _stop_listener() -> str: global _LISTENER, _LISTENER_THREAD, _CLIENT_SOCK, _CLIENT_ADDR, _LISTENER_PORT + listener_sock: socket.socket | None = None + client_sock: socket.socket | None = None + old_thread: threading.Thread | None = None with _LOCK: - if _LISTENER is not None: - try: - _LISTENER.close() - except OSError: - pass - _LISTENER = None + _STOP_EVENT.set() + _READY_EVENT.set() + listener_sock = _LISTENER + old_thread = _LISTENER_THREAD + _LISTENER = None _LISTENER_PORT = None - if _CLIENT_SOCK is not None: - try: - _CLIENT_SOCK.close() - except OSError: - pass - _CLIENT_SOCK = None - _CLIENT_ADDR = None + client_sock = _CLIENT_SOCK + _CLIENT_SOCK = None + _CLIENT_ADDR = None + + if listener_sock is not None: + try: + listener_sock.close() + except OSError: + pass + if client_sock is not None: + try: + client_sock.close() + except OSError: + pass + + # 等待监听线程退出,避免 stop/start 竞态导致“端口 None 仍提示已在监听中” + if old_thread is not None and old_thread.is_alive(): + old_thread.join(timeout=_LISTENER_THREAD_JOIN_TIMEOUT) + with _LOCK: + _LISTENER_THREAD = None return "监听已停止,已断开当前客户端(如有)。"