Files
Shadowbroker/backend/services/mesh/mesh_dm_relay.py
T
BigBodyCobain 401f114e4f DM mailbox: outbound replication + receiving endpoint
Second commit on this branch (first added the per-sender cap + accept_replica
primitive). This commit wires the actual cross-node propagation:

Outbound (sender side)
----------------------
* New ``DMRelay._replicate_envelope_to_peers_async()`` — fire-and-forget
  thread that POSTs the envelope to every authenticated relay peer via
  the same per-peer HMAC pattern gate-message replication uses (#256
  ``X-Peer-Url`` + ``X-Peer-HMAC`` headers, ``resolve_peer_key_for_url``).
* ``deposit()`` now calls the replication helper after a successful
  local accept. Per-peer errors are swallowed — slow Tor peers must not
  block the sender's UX, and the recipient polling from a healthy peer
  works fine even if some peers are down.
* Metrics: dm_replication_push_ok / _rejected / _error.

Inbound (receiving side)
------------------------
* New endpoint ``POST /api/mesh/dm/replicate-envelope`` in
  routers/mesh_peer_sync.py.
* Same HMAC auth gate (``_verify_peer_push_hmac``) as the existing
  infonet/gate peer-push endpoints. Unauthenticated requests get 403.
* Body cap of 64 KB (DM envelope is bounded by MESH_DM_MAX_MSG_BYTES).
* Calls DMRelay.accept_replica which enforces the per-sender cap as a
  network rule — hostile sender's relay can hold extras locally but
  honest peers reject them on inbound replication.

End-to-end flow now works
-------------------------
  1. Alice's node accepts a deposit to Bob's mailbox (local cap check).
  2. Alice's node spawns a background thread that POSTs the envelope
     to MESH_RELAY_PEERS with per-peer HMAC.
  3. Each peer's /api/mesh/dm/replicate-envelope verifies the HMAC and
     calls accept_replica, which re-enforces the per-sender cap.
  4. Bob (offline at the time of send) eventually logs into ANY node
     in MESH_RELAY_PEERS, his existing pollDmMailboxes pulls from
     the local mailbox there, finds Alice's envelope, decrypts.

Tests
-----
backend/tests/test_dm_replicate_envelope_endpoint.py — 4 tests:

  TestReplicateEndpointAuth:
    - rejects requests without peer HMAC (403)
    - rejects requests with WRONG peer HMAC (403) — confirms the
      HMAC is actually verified, not just present
    - rejects oversize bodies (>64 KB) with 400/413

  TestReplicateEndpointRegistered:
    - static check that POST /api/mesh/dm/replicate-envelope is
      registered on app.routes — catches future refactor that
      drops the router include

All 38 backend tests touching the new code paths still pass:
  test_dm_relay_per_sender_cap.py (14)
  test_dm_replicate_envelope_endpoint.py (4)
  test_no_new_duplicate_routes.py (1) — new route is unique
  test_per_peer_secret_resolver.py (19) — HMAC primitive unaffected

What's still ahead (PR-3+)
--------------------------
* ack propagation: when recipient pulls a message on node X, peers Y/Z
  should prune their copies to free the sender's quota network-wide.
  Without this, the sender's quota frees only on the node the recipient
  actually polled — other peers still see N pending until TTL expiry.
  Workable but suboptimal. PR-3 will add a /api/mesh/dm/ack endpoint
  with the same HMAC pattern.
* recipient pull-from-peers: today the recipient's poll only hits
  their own node's relay. If they log into a peer they didn't deposit
  with, they need a way to fetch envelopes from other peers in
  MESH_RELAY_PEERS. Today this works as long as the recipient's
  current node is one of the peers Alice's node pushed to — which is
  true in a fully-meshed deployment but not guaranteed for partial
  meshes. PR-4 if telemetry shows this matters.
2026-05-22 19:23:09 -06:00

2006 lines
86 KiB
Python

"""Metadata-minimized DM relay for request and shared mailboxes.
This relay never decrypts application payloads. In secure mode it keeps
pending ciphertext in memory only and persists just the minimum metadata
needed for continuity: accepted DH bundles, block lists, witness data,
and nonce replay windows.
"""
from __future__ import annotations
import atexit
import hashlib
import json
import logging
import os
import secrets
import threading
import time
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from services.config import get_settings
from services.mesh.mesh_metrics import increment as metrics_inc
from services.mesh.mesh_wormhole_prekey import (
_validate_bundle_record,
transparency_fingerprint_for_bundle_record,
)
from services.mesh.mesh_secure_storage import read_secure_json, write_secure_json
TTL_SECONDS = 3600
EPOCH_SECONDS = 6 * 60 * 60
BACKEND_DIR = Path(__file__).resolve().parents[2]
DEFAULT_DATA_DIR = BACKEND_DIR / "data"
DEFAULT_RELAY_FILE = DEFAULT_DATA_DIR / "dm_relay.json"
DATA_DIR = DEFAULT_DATA_DIR
RELAY_FILE = DEFAULT_RELAY_FILE
logger = logging.getLogger(__name__)
def _stable_json(value: Any) -> str:
return json.dumps(value, sort_keys=True, separators=(",", ":"))
def _get_token_pepper() -> str:
"""Read token pepper lazily so auto-generated values from startup audit take effect."""
pepper = os.environ.get("MESH_DM_TOKEN_PEPPER", "").strip()
if not pepper:
try:
from services.config import get_settings
from services.env_check import _ensure_dm_token_pepper
pepper = _ensure_dm_token_pepper(get_settings())
except Exception:
pepper = os.environ.get("MESH_DM_TOKEN_PEPPER", "").strip()
if not pepper:
raise RuntimeError("MESH_DM_TOKEN_PEPPER is unavailable at runtime")
return pepper
@dataclass
class DMMessage:
sender_id: str
ciphertext: str
timestamp: float
msg_id: str
delivery_class: str
sender_seal: str = ""
relay_salt: str = ""
sender_block_ref: str = ""
payload_format: str = "dm1"
session_welcome: str = ""
class DMRelay:
"""Relay for encrypted request/shared mailboxes."""
def __init__(self) -> None:
self._lock = threading.RLock()
self._mailboxes: dict[str, list[DMMessage]] = defaultdict(list)
self._dh_keys: dict[str, dict[str, Any]] = {}
self._prekey_bundles: dict[str, dict[str, Any]] = {}
self._mailbox_bindings: dict[str, dict[str, Any]] = defaultdict(dict)
self._witnesses: dict[str, list[dict[str, Any]]] = defaultdict(list)
self._blocks: dict[str, set[str]] = defaultdict(set)
self._nonce_caches: dict[str, OrderedDict[str, float]] = {}
"""Per-agent nonce replay caches — keyed by agent_id, values are OrderedDicts of nonce→expiry."""
self._prekey_lookup_aliases: dict[str, dict[str, Any]] = {}
"""Invite-scoped lookup handle → agent_id for prekey bundle fetch without stable identity."""
self._stats: dict[str, int] = {"messages_in_memory": 0}
self._dirty = False
self._save_timer: threading.Timer | None = None
self._last_persist_error = ""
self._SAVE_INTERVAL = 5.0
atexit.register(self._flush)
self._load()
def _settings(self):
return get_settings()
def _persist_spool_enabled(self) -> bool:
return bool(self._settings().MESH_DM_PERSIST_SPOOL)
def _relay_file(self) -> Path:
# Unit tests frequently monkeypatch the module-level relay file so each
# relay instance stays isolated from the shared runtime spool path.
module_override = Path(RELAY_FILE)
if module_override != DEFAULT_RELAY_FILE:
return module_override.expanduser().resolve()
override = str(getattr(self._settings(), "MESH_DM_RELAY_FILE_PATH", "") or "").strip()
if override:
override_path = Path(override).expanduser()
if not override_path.is_absolute():
override_path = BACKEND_DIR / override_path
return override_path.resolve()
return RELAY_FILE
def _relay_data_dir(self) -> Path:
return self._relay_file().parent
def _auto_reload_enabled(self) -> bool:
if Path(RELAY_FILE) != DEFAULT_RELAY_FILE:
return False
return bool(getattr(self._settings(), "MESH_DM_RELAY_AUTO_RELOAD", False))
def _refresh_from_shared_relay(self) -> None:
if self._auto_reload_enabled():
self._reload_snapshot_from_shared_relay()
def _reload_snapshot_from_shared_relay(self) -> None:
relay_file = self._relay_file()
fresh_mailboxes: defaultdict[str, list[DMMessage]] = defaultdict(list)
fresh_dh_keys: dict[str, dict[str, Any]] = {}
fresh_prekey_bundles: dict[str, dict[str, Any]] = {}
fresh_mailbox_bindings: defaultdict[str, dict[str, Any]] = defaultdict(dict)
fresh_witnesses: defaultdict[str, list[dict[str, Any]]] = defaultdict(list)
fresh_blocks: defaultdict[str, set[str]] = defaultdict(set)
fresh_nonce_caches: dict[str, OrderedDict[str, float]] = {}
fresh_prekey_lookup_aliases: dict[str, dict[str, Any]] = {}
fresh_stats: dict[str, int] = {"messages_in_memory": 0}
current_mailboxes = defaultdict(list, {k: list(v) for k, v in self._mailboxes.items()})
current_bindings = defaultdict(
dict,
{
str(agent_id): {
str(kind): dict(entry)
for kind, entry in bindings.items()
if isinstance(entry, dict)
}
for agent_id, bindings in self._mailbox_bindings.items()
if isinstance(bindings, dict)
},
)
if not relay_file.exists():
if not self._persist_spool_enabled():
fresh_mailboxes = current_mailboxes
if not self._metadata_persist_enabled():
fresh_mailbox_bindings = current_bindings
self._mailboxes = fresh_mailboxes
self._dh_keys = fresh_dh_keys
self._prekey_bundles = fresh_prekey_bundles
self._mailbox_bindings = fresh_mailbox_bindings
self._witnesses = fresh_witnesses
self._blocks = fresh_blocks
self._nonce_caches = fresh_nonce_caches
self._prekey_lookup_aliases = fresh_prekey_lookup_aliases
self._stats = fresh_stats
return
try:
data = read_secure_json(relay_file, lambda: {})
except Exception:
return
if self._persist_spool_enabled():
mailboxes = data.get("mailboxes", {})
if isinstance(mailboxes, dict):
for key, items in mailboxes.items():
if not isinstance(items, list):
continue
restored: list[DMMessage] = []
for item in items:
try:
restored.append(
DMMessage(
sender_id=str(item.get("sender_id", "")),
ciphertext=str(item.get("ciphertext", "")),
timestamp=float(item.get("timestamp", 0)),
msg_id=str(item.get("msg_id", "")),
delivery_class=str(item.get("delivery_class", "shared")),
sender_seal=str(item.get("sender_seal", "")),
relay_salt=str(item.get("relay_salt", "") or ""),
sender_block_ref=str(item.get("sender_block_ref", "") or ""),
payload_format=str(item.get("payload_format", item.get("format", "dm1")) or "dm1"),
session_welcome=str(item.get("session_welcome", "") or ""),
)
)
except Exception:
continue
for message in restored:
if not message.sender_block_ref:
message.sender_block_ref = self._message_block_ref(message)
if restored:
fresh_mailboxes[str(key)] = restored
else:
if not self._persist_spool_enabled():
fresh_mailboxes = current_mailboxes
dh_keys = data.get("dh_keys", {})
if isinstance(dh_keys, dict):
fresh_dh_keys = {str(k): dict(v) for k, v in dh_keys.items() if isinstance(v, dict)}
prekey_bundles = data.get("prekey_bundles", {})
if isinstance(prekey_bundles, dict):
fresh_prekey_bundles = {
str(k): dict(v) for k, v in prekey_bundles.items() if isinstance(v, dict)
}
prekey_lookup_aliases = data.get("prekey_lookup_aliases", {})
if isinstance(prekey_lookup_aliases, dict):
for key, value in prekey_lookup_aliases.items():
handle = str(key or "").strip()
record = self._coerce_prekey_lookup_alias_record(value)
if handle and record:
fresh_prekey_lookup_aliases[handle] = record
now = time.time()
mailbox_bindings = data.get("mailbox_bindings", {})
if isinstance(mailbox_bindings, dict) and self._metadata_persist_enabled():
for agent_id, bindings in mailbox_bindings.items():
if not isinstance(bindings, dict):
continue
restored_agent: dict[str, dict[str, Any]] = {}
for kind, entry in bindings.items():
token_hash = ""
last_used = now
if isinstance(entry, dict):
token_hash = str(entry.get("token_hash", "") or "").strip()
last_used = float(entry.get("last_used", now) or now)
else:
token_hash = str(entry or "").strip()
if token_hash:
normalized = self._coerce_mailbox_binding_entry(
{
"token_hash": token_hash,
"bound_at": float(entry.get("bound_at", last_used) or last_used)
if isinstance(entry, dict)
else last_used,
"last_used": last_used,
"expires_at": float(entry.get("expires_at", 0) or 0)
if isinstance(entry, dict)
else 0,
},
now=now,
)
if normalized:
restored_agent[str(kind)] = normalized
if restored_agent:
fresh_mailbox_bindings[str(agent_id)] = restored_agent
elif not self._metadata_persist_enabled():
fresh_mailbox_bindings = current_bindings
witnesses = data.get("witnesses", {})
if isinstance(witnesses, dict):
fresh_witnesses = defaultdict(
list,
{str(k): list(v) for k, v in witnesses.items() if isinstance(v, list)},
)
blocks = data.get("blocks", {})
if isinstance(blocks, dict):
for key, values in blocks.items():
if isinstance(values, list):
fresh_blocks[str(key)] = {
self._canonical_blocked_id(str(v))
for v in values
if str(v or "").strip()
}
nonce_caches = data.get("nonce_caches", {})
if isinstance(nonce_caches, dict) and nonce_caches:
for aid, entries in nonce_caches.items():
if not isinstance(entries, dict):
continue
restored = sorted(
((str(k), float(v)) for k, v in entries.items() if float(v) > now),
key=lambda item: item[1],
)
if restored:
fresh_nonce_caches[str(aid)] = OrderedDict(restored)
else:
nonce_cache = data.get("nonce_cache", {})
if isinstance(nonce_cache, dict):
for composite_key, expiry in nonce_cache.items():
if float(expiry) <= now:
continue
parts = str(composite_key).split(":", 1)
if len(parts) == 2:
aid, nonce_val = parts
if aid not in fresh_nonce_caches:
fresh_nonce_caches[aid] = OrderedDict()
fresh_nonce_caches[aid][nonce_val] = float(expiry)
stats = data.get("stats", {})
if isinstance(stats, dict):
fresh_stats = {str(k): int(v) for k, v in stats.items() if isinstance(v, (int, float))}
self._mailboxes = fresh_mailboxes
self._dh_keys = fresh_dh_keys
self._prekey_bundles = fresh_prekey_bundles
self._mailbox_bindings = fresh_mailbox_bindings
self._witnesses = fresh_witnesses
self._blocks = fresh_blocks
self._nonce_caches = fresh_nonce_caches
self._prekey_lookup_aliases = fresh_prekey_lookup_aliases
self._stats = fresh_stats
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
if self._prune_stale_metadata():
self._dirty = True
def _request_mailbox_limit(self) -> int:
return max(1, int(self._settings().MESH_DM_REQUEST_MAILBOX_LIMIT))
def _shared_mailbox_limit(self) -> int:
return max(1, int(self._settings().MESH_DM_SHARED_MAILBOX_LIMIT))
def _self_mailbox_limit(self) -> int:
return max(1, int(self._settings().MESH_DM_SELF_MAILBOX_LIMIT))
def _per_sender_pending_limit(self) -> int:
"""Anti-spam cap on UNACKED messages a single sender can have parked
in a single recipient mailbox at any one time. See ``config.py``
``MESH_DM_PENDING_PER_SENDER_LIMIT`` for the threat model — this
rule is enforced both at ``deposit`` (local) and at
``accept_replica`` (peer push acceptance), making it a network
rule rather than a client-side honor system."""
try:
limit = int(getattr(self._settings(), "MESH_DM_PENDING_PER_SENDER_LIMIT", 2) or 2)
except (TypeError, ValueError):
limit = 2
return max(1, limit)
def _per_sender_pending_count(
self,
*,
mailbox_key: str,
sender_block_ref: str,
) -> int:
"""Count UNACKED messages from ``sender_block_ref`` currently parked
in ``mailbox_key``. Caller already holds ``self._lock``.
Messages that have been claimed/acked are removed from the mailbox
list (see ``claim_message_ids``), so anything still here is by
definition unacked. We count by exact ``sender_block_ref`` match
— that's the per-pair sender identity used for blocking too, so
the cap is naturally per-(sender, recipient).
"""
if not mailbox_key or not sender_block_ref:
return 0
messages = self._mailboxes.get(mailbox_key, [])
return sum(1 for m in messages if m.sender_block_ref == sender_block_ref)
def _nonce_ttl_seconds(self) -> int:
return max(30, int(self._settings().MESH_DM_NONCE_TTL_S))
def _nonce_cache_max_entries(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_NONCE_CACHE_MAX", 4096)))
def _nonce_per_agent_max(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_NONCE_PER_AGENT_MAX", 256)))
def _dm_key_ttl_seconds(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_KEY_TTL_DAYS", 30) or 30)) * 86400
def _prekey_lookup_alias_ttl_seconds(self) -> int:
return max(
1,
int(getattr(self._settings(), "MESH_DM_PREKEY_LOOKUP_ALIAS_TTL_DAYS", 14) or 14),
) * 86400
def _witness_ttl_seconds(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_WITNESS_TTL_DAYS", 14) or 14)) * 86400
def _mailbox_binding_ttl_seconds(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_BINDING_TTL_DAYS", 3) or 3)) * 86400
def _mailbox_binding_idle_ttl_seconds(self) -> int:
return min(self._mailbox_binding_ttl_seconds(), 12 * 60 * 60)
def _mailbox_binding_refresh_after_seconds(self) -> int:
return max(15 * 60, min(self._mailbox_binding_ttl_seconds(), 12 * 60 * 60))
def _mailbox_binding_expires_at(self, entry: dict[str, Any]) -> float:
bound_at = float(entry.get("bound_at", 0) or 0)
last_used = float(entry.get("last_used", bound_at) or bound_at)
if bound_at <= 0:
return 0.0
return min(
bound_at + self._mailbox_binding_ttl_seconds(),
last_used + self._mailbox_binding_idle_ttl_seconds(),
)
def _coerce_mailbox_binding_entry(self, entry: Any, *, now: float | None = None) -> dict[str, Any]:
current = time.time() if now is None else float(now)
token_hash = ""
bound_at = current
last_used = current
if isinstance(entry, dict):
token_hash = str(entry.get("token_hash", "") or "").strip()
bound_at = float(entry.get("bound_at", entry.get("last_used", current)) or current)
last_used = float(entry.get("last_used", bound_at) or bound_at)
else:
token_hash = str(entry or "").strip()
if not token_hash:
return {}
normalized = {
"token_hash": token_hash,
"bound_at": bound_at,
"last_used": last_used,
}
normalized["expires_at"] = self._mailbox_binding_expires_at(normalized)
return normalized
def _alias_updated_at_for_agent(self, agent_id: str) -> float:
stored = self._prekey_bundles.get(str(agent_id or "").strip(), {})
if isinstance(stored, dict):
return float(stored.get("updated_at", stored.get("timestamp", time.time())) or time.time())
return float(time.time())
def _make_prekey_lookup_alias_record(
self,
agent_id: str,
*,
updated_at: float | None = None,
expires_at: int = 0,
max_uses: int = 0,
use_count: int = 0,
last_used_at: float = 0,
) -> dict[str, Any]:
aid = str(agent_id or "").strip()
if not aid:
return {}
if updated_at is None:
updated_at = self._alias_updated_at_for_agent(aid)
return {
"agent_id": aid,
"updated_at": float(updated_at or self._alias_updated_at_for_agent(aid)),
"expires_at": max(0, int(expires_at or 0)),
"max_uses": max(0, int(max_uses or 0)),
"use_count": max(0, int(use_count or 0)),
"last_used_at": float(last_used_at or 0),
}
def _coerce_prekey_lookup_alias_record(self, value: Any) -> dict[str, Any]:
if isinstance(value, dict):
aid = str(value.get("agent_id", "") or "").strip()
if not aid:
return {}
updated_at = float(
value.get("updated_at", value.get("last_used", self._alias_updated_at_for_agent(aid)))
or self._alias_updated_at_for_agent(aid)
)
return self._make_prekey_lookup_alias_record(
aid,
updated_at=updated_at,
expires_at=int(value.get("expires_at", 0) or 0),
max_uses=int(value.get("max_uses", 0) or 0),
use_count=int(value.get("use_count", value.get("uses", 0)) or 0),
last_used_at=float(value.get("last_used_at", value.get("last_used", 0)) or 0),
)
aid = str(value or "").strip()
if not aid:
return {}
return self._make_prekey_lookup_alias_record(aid)
def _resolve_prekey_lookup_alias(self, lookup_token: str) -> str:
handle = str(lookup_token or "").strip()
if not handle:
return ""
record = self._coerce_prekey_lookup_alias_record(self._prekey_lookup_aliases.get(handle, {}))
if not record:
return ""
now = time.time()
expires_at = int(record.get("expires_at", 0) or 0)
max_uses = int(record.get("max_uses", 0) or 0)
use_count = int(record.get("use_count", 0) or 0)
if (expires_at > 0 and now > expires_at) or (max_uses > 0 and use_count >= max_uses):
self._prekey_lookup_aliases.pop(handle, None)
self._save()
return ""
updated = self._make_prekey_lookup_alias_record(
str(record.get("agent_id", "") or "").strip(),
updated_at=float(record.get("updated_at", self._alias_updated_at_for_agent(str(record.get("agent_id", "") or "").strip())) or now),
expires_at=expires_at,
max_uses=max_uses,
use_count=use_count + 1,
last_used_at=now,
)
self._prekey_lookup_aliases[handle] = updated
self._save()
try:
from services.mesh.mesh_wormhole_identity import record_prekey_lookup_handle_use
record_prekey_lookup_handle_use(handle, now=int(now))
except Exception:
pass
return str(updated.get("agent_id", "") or "").strip()
def _pepper_token(self, token: str) -> str:
material = token
pepper = _get_token_pepper()
if pepper:
material = f"{pepper}|{token}"
return hashlib.sha256(material.encode("utf-8")).hexdigest()
def _legacy_sender_block_ref(self, sender_id: str) -> str:
sender = str(sender_id or "").strip()
if not sender:
return ""
return "ref:" + self._pepper_token(f"block|{sender}")
def _sender_block_scope(
self,
*,
recipient_id: str = "",
recipient_token: str = "",
delivery_class: str = "",
) -> str:
recipient = str(recipient_id or "").strip()
if recipient:
return f"recipient|{recipient}"
token = str(recipient_token or "").strip()
if token and str(delivery_class or "").strip().lower() == "shared":
return f"shared|{self._hashed_mailbox_token(token)}"
return ""
def _sender_block_ref(self, sender_id: str, *, scope: str = "") -> str:
sender = str(sender_id or "").strip()
if not sender:
return ""
material = f"block|{scope}|{sender}" if scope else f"block|{sender}"
return "ref:" + self._pepper_token(material)
def _sender_block_refs(
self,
sender_id: str,
*,
recipient_id: str = "",
recipient_token: str = "",
delivery_class: str = "",
) -> set[str]:
refs: set[str] = set()
legacy = self._legacy_sender_block_ref(sender_id)
if legacy:
refs.add(legacy)
scoped = self._sender_block_ref(
sender_id,
scope=self._sender_block_scope(
recipient_id=recipient_id,
recipient_token=recipient_token,
delivery_class=delivery_class,
),
)
if scoped:
refs.add(scoped)
return refs
def _canonical_blocked_id(self, blocked_id: str, *, scope: str = "") -> str:
blocked = str(blocked_id or "").strip()
if not blocked:
return ""
if blocked.startswith("ref:"):
return blocked
return self._sender_block_ref(blocked, scope=scope)
def _message_block_ref(self, message: DMMessage) -> str:
block_ref = str(getattr(message, "sender_block_ref", "") or "").strip()
if block_ref:
return block_ref
sender_id = str(message.sender_id or "").strip()
if not sender_id or sender_id.startswith("sealed:") or sender_id.startswith("sender_token:"):
return ""
return self._legacy_sender_block_ref(sender_id)
def _mailbox_key(self, mailbox_type: str, mailbox_value: str, epoch: int | None = None) -> str:
if mailbox_type in {"self", "requests"}:
bucket = self._epoch_bucket() if epoch is None else int(epoch)
identifier = f"{mailbox_type}|{bucket}|{mailbox_value}"
else:
identifier = f"{mailbox_type}|{mailbox_value}"
return self._pepper_token(identifier)
def _hashed_mailbox_token(self, token: str) -> str:
return hashlib.sha256(str(token or "").encode("utf-8")).hexdigest()
def _remember_mailbox_binding(self, agent_id: str, mailbox_type: str, token: str) -> str:
if self._prune_stale_mailbox_bindings():
self._save()
now = time.time()
agent_key = str(agent_id or "").strip()
mailbox_key = str(mailbox_type or "").strip().lower()
token_hash = self._hashed_mailbox_token(token)
current = self._coerce_mailbox_binding_entry(
self._mailbox_bindings.get(agent_key, {}).get(mailbox_key, {}),
now=now,
)
refreshed = {
"token_hash": token_hash,
"bound_at": now,
"last_used": now,
}
if current and str(current.get("token_hash", "") or "") == token_hash:
refreshed["bound_at"] = float(current.get("bound_at", now) or now)
if (now - refreshed["bound_at"]) >= self._mailbox_binding_refresh_after_seconds():
refreshed["bound_at"] = now
refreshed["expires_at"] = self._mailbox_binding_expires_at(refreshed)
self._mailbox_bindings[agent_key][mailbox_key] = refreshed
self._save()
return token_hash
def _bound_mailbox_key(self, agent_id: str, mailbox_type: str) -> str:
if self._prune_stale_mailbox_bindings():
self._save()
agent_key = str(agent_id or "").strip()
mailbox_key = str(mailbox_type or "").strip().lower()
entry = self._mailbox_bindings.get(agent_key, {}).get(
mailbox_key,
{},
)
normalized = self._coerce_mailbox_binding_entry(entry)
if normalized and normalized != entry:
self._mailbox_bindings[agent_key][mailbox_key] = normalized
self._save()
return str(normalized.get("token_hash", "") or "")
def _mailbox_keys_for_claim(self, agent_id: str, claim: dict[str, Any]) -> list[str]:
claim_type = str(claim.get("type", "")).strip().lower()
if claim_type == "shared":
token = str(claim.get("token", "")).strip()
if not token:
metrics_inc("dm_claim_invalid")
return []
return [self._hashed_mailbox_token(token)]
if claim_type == "requests":
token = str(claim.get("token", "")).strip()
if token:
previous_bound = self._bound_mailbox_key(agent_id, "requests")
bound_key = self._remember_mailbox_binding(agent_id, "requests", token)
epoch = self._epoch_bucket()
return [
key
for key in [
previous_bound,
bound_key,
self._mailbox_key("requests", agent_id, epoch),
self._mailbox_key("requests", agent_id, epoch - 1),
]
if key
]
metrics_inc("dm_claim_invalid")
return []
if claim_type == "self":
token = str(claim.get("token", "")).strip()
if token:
previous_bound = self._bound_mailbox_key(agent_id, "self")
bound_key = self._remember_mailbox_binding(agent_id, "self", token)
epoch = self._epoch_bucket()
return [
key
for key in [
previous_bound,
bound_key,
self._mailbox_key("self", agent_id, epoch),
self._mailbox_key("self", agent_id, epoch - 1),
]
if key
]
metrics_inc("dm_claim_invalid")
return []
metrics_inc("dm_claim_invalid")
return []
def mailbox_key_for_delivery(
self,
*,
recipient_id: str,
delivery_class: str,
recipient_token: str | None = None,
) -> str:
with self._lock:
self._refresh_from_shared_relay()
delivery_class = str(delivery_class or "").strip().lower()
if delivery_class == "request":
bound_key = self._bound_mailbox_key(recipient_id, "requests")
if bound_key:
return bound_key
return self._mailbox_key("requests", str(recipient_id or "").strip())
if delivery_class == "shared":
token = str(recipient_token or "").strip()
if not token:
raise ValueError("recipient_token required for shared delivery")
return self._hashed_mailbox_token(token)
raise ValueError("Unsupported delivery_class")
def claim_mailbox_keys(self, agent_id: str, claims: list[dict[str, Any]]) -> list[str]:
with self._lock:
self._refresh_from_shared_relay()
if self._prune_stale_mailbox_bindings():
self._save()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
return list(dict.fromkeys(keys))
def _legacy_mailbox_token(self, agent_id: str, epoch: int) -> str:
raw = f"sb_dm|{epoch}|{agent_id}".encode("utf-8")
return hashlib.sha256(raw).hexdigest()
def _legacy_token_candidates(self, agent_id: str) -> list[str]:
epoch = self._epoch_bucket()
raw = [self._legacy_mailbox_token(agent_id, epoch), self._legacy_mailbox_token(agent_id, epoch - 1)]
peppered = [self._pepper_token(token) for token in raw]
return list(dict.fromkeys(peppered + raw))
def _save(self) -> None:
"""Mark dirty and schedule a coalesced disk write."""
self._dirty = True
relay_file = self._relay_file()
if self._auto_reload_enabled() or not relay_file.exists() or self._persist_failures_are_fatal():
self._flush()
return
with self._lock:
if self._save_timer is None or not self._save_timer.is_alive():
self._save_timer = threading.Timer(self._SAVE_INTERVAL, self._flush)
self._save_timer.daemon = True
self._save_timer.start()
def _persist_failures_are_fatal(self) -> bool:
return bool(os.environ.get("PYTEST_CURRENT_TEST", "").strip())
def _record_persist_failure(self, operation: str, exc: Exception) -> None:
self._last_persist_error = f"{operation}:{type(exc).__name__}:{exc}"
metrics_inc("dm_relay_persist_failure")
logger.exception("dm relay %s failed for %s", operation, self._relay_file())
def _prune_stale_metadata(self) -> bool:
"""Remove expired relay metadata that should not outlive its retention window."""
now = time.time()
key_ttl = self._dm_key_ttl_seconds()
alias_ttl = self._prekey_lookup_alias_ttl_seconds()
witness_ttl = self._witness_ttl_seconds()
changed = False
stale_keys = [
aid for aid, entry in self._dh_keys.items()
if (now - float(entry.get("timestamp", 0) or 0)) > key_ttl
]
for aid in stale_keys:
del self._dh_keys[aid]
changed = True
stale_bundles = [
aid for aid, entry in self._prekey_bundles.items()
if (now - float(entry.get("updated_at", entry.get("timestamp", 0)) or 0)) > key_ttl
]
for aid in stale_bundles:
del self._prekey_bundles[aid]
changed = True
stale_aliases: list[str] = []
for alias, value in list(self._prekey_lookup_aliases.items()):
record = self._coerce_prekey_lookup_alias_record(value)
if not record:
stale_aliases.append(alias)
continue
if self._prekey_lookup_aliases.get(alias) != record:
self._prekey_lookup_aliases[alias] = record
changed = True
target = str(record.get("agent_id", "") or "").strip()
updated_at = float(record.get("updated_at", self._alias_updated_at_for_agent(target)) or 0)
expires_at = int(record.get("expires_at", 0) or 0)
max_uses = int(record.get("max_uses", 0) or 0)
use_count = int(record.get("use_count", 0) or 0)
if (
not target
or target not in self._prekey_bundles
or (now - updated_at) > alias_ttl
or (expires_at > 0 and now > float(expires_at))
or (max_uses > 0 and use_count >= max_uses)
):
stale_aliases.append(alias)
for alias in stale_aliases:
del self._prekey_lookup_aliases[alias]
changed = True
for target_id in list(self._witnesses):
fresh = [
witness
for witness in self._witnesses.get(target_id, [])
if (now - float(witness.get("timestamp", 0) or 0)) <= witness_ttl
]
if len(fresh) != len(self._witnesses.get(target_id, [])):
changed = True
if fresh:
self._witnesses[target_id] = fresh
else:
del self._witnesses[target_id]
if self._prune_stale_mailbox_bindings(now=now):
changed = True
return changed
def _prune_stale_mailbox_bindings(self, *, now: float | None = None) -> bool:
current = time.time() if now is None else now
changed = False
stale_agents: list[str] = []
for agent_id, kinds in self._mailbox_bindings.items():
normalized_updates: dict[str, dict[str, Any]] = {}
expired_kinds = [
k
for k, v in kinds.items()
if not self._coerce_mailbox_binding_entry(v, now=current)
or current > self._mailbox_binding_expires_at(
self._coerce_mailbox_binding_entry(v, now=current)
)
]
for kind, entry in list(kinds.items()):
normalized = self._coerce_mailbox_binding_entry(entry, now=current)
if normalized and normalized != entry:
normalized_updates[kind] = normalized
for kind, normalized in normalized_updates.items():
kinds[kind] = normalized
changed = True
for k in expired_kinds:
del kinds[k]
changed = True
if not kinds:
stale_agents.append(agent_id)
for agent_id in stale_agents:
del self._mailbox_bindings[agent_id]
changed = True
return changed
def _metadata_persist_enabled(self) -> bool:
settings = self._settings()
return bool(getattr(settings, "MESH_DM_METADATA_PERSIST", False)) and bool(
getattr(settings, "MESH_DM_METADATA_PERSIST_ACKNOWLEDGE", False)
)
def _flush(self) -> None:
"""Actually write to disk (called by timer or atexit)."""
if not self._dirty:
return
try:
self._prune_stale_metadata()
relay_file = self._relay_file()
self._relay_data_dir().mkdir(parents=True, exist_ok=True)
payload: dict[str, Any] = {
"saved_at": int(time.time()),
"dh_keys": self._dh_keys,
"prekey_bundles": self._prekey_bundles,
"prekey_lookup_aliases": self._prekey_lookup_aliases,
"witnesses": self._witnesses,
"blocks": {k: sorted(v) for k, v in self._blocks.items()},
"nonce_caches": {aid: dict(c) for aid, c in self._nonce_caches.items()},
"stats": self._stats,
}
if self._metadata_persist_enabled():
payload["mailbox_bindings"] = {
agent_id: {
mailbox_type: {
"token_hash": str(entry.get("token_hash", "") or "").strip(),
"bound_at": float(entry.get("bound_at", 0) or 0),
"last_used": float(entry.get("last_used", 0) or 0),
"expires_at": float(entry.get("expires_at", 0) or 0),
}
for mailbox_type, entry in bindings.items()
if isinstance(entry, dict) and str(entry.get("token_hash", "") or "").strip()
}
for agent_id, bindings in self._mailbox_bindings.items()
if isinstance(bindings, dict)
}
if self._persist_spool_enabled():
payload["mailboxes"] = {
key: [m.__dict__ for m in msgs] for key, msgs in self._mailboxes.items()
}
write_secure_json(relay_file, payload)
self._dirty = False
self._last_persist_error = ""
except Exception as exc:
self._record_persist_failure("flush", exc)
if self._persist_failures_are_fatal():
raise
def _load(self) -> None:
relay_file = self._relay_file()
if not relay_file.exists():
return
try:
data = read_secure_json(relay_file, lambda: {})
except Exception:
return
if self._persist_spool_enabled():
mailboxes = data.get("mailboxes", {})
if isinstance(mailboxes, dict):
for key, items in mailboxes.items():
if not isinstance(items, list):
continue
restored: list[DMMessage] = []
for item in items:
try:
restored.append(
DMMessage(
sender_id=str(item.get("sender_id", "")),
ciphertext=str(item.get("ciphertext", "")),
timestamp=float(item.get("timestamp", 0)),
msg_id=str(item.get("msg_id", "")),
delivery_class=str(item.get("delivery_class", "shared")),
sender_seal=str(item.get("sender_seal", "")),
relay_salt=str(item.get("relay_salt", "") or ""),
sender_block_ref=str(item.get("sender_block_ref", "") or ""),
payload_format=str(item.get("payload_format", item.get("format", "dm1")) or "dm1"),
session_welcome=str(item.get("session_welcome", "") or ""),
)
)
except Exception:
continue
for message in restored:
if not message.sender_block_ref:
message.sender_block_ref = self._message_block_ref(message)
if restored:
self._mailboxes[key] = restored
dh_keys = data.get("dh_keys", {})
if isinstance(dh_keys, dict):
self._dh_keys = {str(k): dict(v) for k, v in dh_keys.items() if isinstance(v, dict)}
prekey_bundles = data.get("prekey_bundles", {})
if isinstance(prekey_bundles, dict):
self._prekey_bundles = {
str(k): dict(v) for k, v in prekey_bundles.items() if isinstance(v, dict)
}
prekey_lookup_aliases = data.get("prekey_lookup_aliases", {})
if isinstance(prekey_lookup_aliases, dict):
restored_aliases: dict[str, dict[str, Any]] = {}
alias_records_migrated = False
for key, value in prekey_lookup_aliases.items():
handle = str(key or "").strip()
record = self._coerce_prekey_lookup_alias_record(value)
if not handle or not record:
continue
restored_aliases[handle] = record
if value != record:
alias_records_migrated = True
self._prekey_lookup_aliases = restored_aliases
if alias_records_migrated:
self._dirty = True
now = time.time()
mailbox_bindings = data.get("mailbox_bindings", {})
if isinstance(mailbox_bindings, dict):
if self._metadata_persist_enabled():
restored_bindings: dict[str, dict[str, dict[str, Any]]] = {}
for agent_id, bindings in mailbox_bindings.items():
if not isinstance(bindings, dict):
continue
restored_agent: dict[str, dict[str, Any]] = {}
for kind, entry in bindings.items():
token_hash = ""
last_used = now
if isinstance(entry, dict):
token_hash = str(entry.get("token_hash", "") or "").strip()
last_used = float(entry.get("last_used", now) or now)
else:
token_hash = str(entry or "").strip()
if not token_hash:
continue
normalized = self._coerce_mailbox_binding_entry(
{
"token_hash": token_hash,
"bound_at": float(entry.get("bound_at", last_used) or last_used)
if isinstance(entry, dict)
else last_used,
"last_used": last_used,
"expires_at": float(entry.get("expires_at", 0) or 0)
if isinstance(entry, dict)
else 0,
},
now=now,
)
if normalized:
restored_agent[str(kind)] = normalized
if restored_agent:
restored_bindings[str(agent_id)] = restored_agent
self._mailbox_bindings = defaultdict(dict, restored_bindings)
elif mailbox_bindings:
# Old relay files may still contain persisted mailbox bindings.
# When metadata persistence is disabled we intentionally do not
# restore them, and mark dirty so the next flush rewrites the
# relay state without that graph metadata.
self._dirty = True
witnesses = data.get("witnesses", {})
if isinstance(witnesses, dict):
self._witnesses = defaultdict(
list,
{
str(k): list(v)
for k, v in witnesses.items()
if isinstance(v, list)
},
)
blocks = data.get("blocks", {})
if isinstance(blocks, dict):
for key, values in blocks.items():
if isinstance(values, list):
self._blocks[str(key)] = {
self._canonical_blocked_id(str(v))
for v in values
if str(v or "").strip()
}
nonce_caches = data.get("nonce_caches", {})
if isinstance(nonce_caches, dict) and nonce_caches:
for aid, entries in nonce_caches.items():
if not isinstance(entries, dict):
continue
restored = sorted(
((str(k), float(v)) for k, v in entries.items() if float(v) > now),
key=lambda item: item[1],
)
if restored:
self._nonce_caches[str(aid)] = OrderedDict(restored)
else:
# Backward compatibility: migrate flat nonce_cache → per-agent
nonce_cache = data.get("nonce_cache", {})
if isinstance(nonce_cache, dict):
for composite_key, expiry in nonce_cache.items():
if float(expiry) <= now:
continue
parts = str(composite_key).split(":", 1)
if len(parts) == 2:
aid, nonce_val = parts
if aid not in self._nonce_caches:
self._nonce_caches[aid] = OrderedDict()
self._nonce_caches[aid][nonce_val] = float(expiry)
stats = data.get("stats", {})
if isinstance(stats, dict):
self._stats = {str(k): int(v) for k, v in stats.items() if isinstance(v, (int, float))}
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
if self._prune_stale_metadata():
self._dirty = True
def _bundle_fingerprint(
self,
*,
dh_pub_key: str,
dh_algo: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
) -> str:
material = "|".join(
[
dh_pub_key,
dh_algo,
public_key,
public_key_algo,
protocol_version,
]
)
return hashlib.sha256(material.encode("utf-8")).hexdigest()
def _advance_prekey_transparency(
self,
*,
agent_id: str,
bundle: dict[str, Any],
signature: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
sequence: int,
existing: dict[str, Any] | None,
) -> dict[str, Any]:
previous_head = str((existing or {}).get("prekey_transparency_head", "") or "").strip().lower()
previous_size = int((existing or {}).get("prekey_transparency_size", 0) or 0)
publication_fingerprint = transparency_fingerprint_for_bundle_record(
{
"agent_id": agent_id,
"bundle": bundle,
"signature": signature,
"public_key": public_key,
"public_key_algo": public_key_algo,
"protocol_version": protocol_version,
"sequence": int(sequence),
}
)
next_size = previous_size + 1
head_payload = {
"agent_id": agent_id,
"sequence": int(sequence),
"signed_at": int(bundle.get("signed_at", 0) or 0),
"publication_fingerprint": publication_fingerprint,
"previous_head": previous_head,
"index": next_size,
}
head = hashlib.sha256(_stable_json(head_payload).encode("utf-8")).hexdigest()
history = list((existing or {}).get("prekey_transparency_log") or [])
history.append(
{
"index": next_size,
"sequence": int(sequence),
"signed_at": int(bundle.get("signed_at", 0) or 0),
"publication_fingerprint": publication_fingerprint,
"previous_head": previous_head,
"head": head,
"observed_at": int(time.time()),
}
)
return {
"prekey_transparency_head": head,
"prekey_transparency_size": next_size,
"prekey_transparency_fingerprint": publication_fingerprint,
"prekey_transparency_log": history[-16:],
}
def register_dh_key(
self,
agent_id: str,
dh_pub_key: str,
dh_algo: str,
timestamp: int,
signature: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
sequence: int,
) -> tuple[bool, str, dict[str, Any] | None]:
"""Register/update an agent's DH public key bundle with replay protection."""
fingerprint = self._bundle_fingerprint(
dh_pub_key=dh_pub_key,
dh_algo=dh_algo,
public_key=public_key,
public_key_algo=public_key_algo,
protocol_version=protocol_version,
)
with self._lock:
self._refresh_from_shared_relay()
existing = self._dh_keys.get(agent_id)
if existing:
existing_seq = int(existing.get("sequence", 0) or 0)
existing_ts = int(existing.get("timestamp", 0) or 0)
if sequence <= existing_seq:
metrics_inc("dm_key_replay")
return False, "DM key replay or rollback rejected", None
if timestamp < existing_ts:
metrics_inc("dm_key_stale")
return False, "DM key timestamp is older than the current bundle", None
self._dh_keys[agent_id] = {
"dh_pub_key": dh_pub_key,
"dh_algo": dh_algo,
"timestamp": timestamp,
"signature": signature,
"public_key": public_key,
"public_key_algo": public_key_algo,
"protocol_version": protocol_version,
"sequence": sequence,
"bundle_fingerprint": fingerprint,
}
self._save()
return True, "ok", {
"accepted_sequence": sequence,
"bundle_fingerprint": fingerprint,
}
def get_dh_key(self, agent_id: str) -> dict[str, Any] | None:
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
return self._dh_keys.get(agent_id)
def get_dh_key_by_lookup(self, lookup_token: str) -> tuple[dict[str, Any] | None, str]:
"""Resolve a prekey lookup alias and return the DH key for the resolved agent."""
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
resolved_id = self._resolve_prekey_lookup_alias(lookup_token)
if not resolved_id:
return None, ""
stored = self._dh_keys.get(resolved_id)
if not stored:
return None, ""
return dict(stored), resolved_id
def register_prekey_bundle(
self,
agent_id: str,
bundle: dict[str, Any],
signature: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
sequence: int,
lookup_aliases: list[Any] | None = None,
) -> tuple[bool, str, dict[str, Any] | None]:
ok, reason = _validate_bundle_record(
{
"bundle": bundle,
"public_key": public_key,
"agent_id": agent_id,
}
)
if not ok:
return False, reason, None
with self._lock:
self._refresh_from_shared_relay()
existing = self._prekey_bundles.get(agent_id)
if existing:
existing_seq = int(existing.get("sequence", 0) or 0)
if sequence <= existing_seq:
return False, "Prekey bundle replay or rollback rejected", None
transparency = self._advance_prekey_transparency(
agent_id=agent_id,
bundle=dict(bundle or {}),
signature=signature,
public_key=public_key,
public_key_algo=public_key_algo,
protocol_version=protocol_version,
sequence=int(sequence),
existing=existing,
)
stored = {
"bundle": dict(bundle or {}),
"signature": signature,
"public_key": public_key,
"public_key_algo": public_key_algo,
"protocol_version": protocol_version,
"sequence": int(sequence),
"updated_at": int(time.time()),
**transparency,
}
self._prekey_bundles[agent_id] = stored
if lookup_aliases:
alias_updated_at = float(stored.get("updated_at", time.time()) or time.time())
for alias in lookup_aliases[:16]:
alias_record = self._coerce_prekey_lookup_alias_record(
{
"agent_id": agent_id,
**dict(alias),
}
if isinstance(alias, dict)
else self._make_prekey_lookup_alias_record(agent_id, updated_at=alias_updated_at)
)
handle = str(alias.get("handle", "") if isinstance(alias, dict) else alias or "").strip()
if handle:
self._prekey_lookup_aliases[handle] = self._make_prekey_lookup_alias_record(
agent_id,
updated_at=alias_updated_at,
expires_at=int(alias_record.get("expires_at", 0) or 0),
max_uses=int(alias_record.get("max_uses", 0) or 0),
use_count=int(alias_record.get("use_count", 0) or 0),
last_used_at=float(alias_record.get("last_used_at", 0) or 0),
)
self._save()
return True, "ok", {"accepted_sequence": int(sequence), **transparency}
def get_prekey_bundle(self, agent_id: str) -> dict[str, Any] | None:
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
stored = self._prekey_bundles.get(agent_id)
if not stored:
return None
return dict(stored)
def get_prekey_bundle_by_lookup(self, lookup_token: str) -> tuple[dict[str, Any] | None, str]:
"""Resolve a lookup alias to a prekey bundle. Returns (bundle, agent_id)."""
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
resolved_id = self._resolve_prekey_lookup_alias(lookup_token)
if not resolved_id:
return None, ""
stored = self._prekey_bundles.get(resolved_id)
if not stored:
return None, ""
return dict(stored), resolved_id
def register_prekey_lookup_alias(
self,
alias: str,
agent_id: str,
*,
expires_at: int = 0,
max_uses: int = 0,
use_count: int = 0,
last_used_at: int = 0,
) -> None:
"""Register a lookup alias for an agent's prekey bundle."""
handle = str(alias or "").strip()
aid = str(agent_id or "").strip()
if handle and aid:
with self._lock:
self._refresh_from_shared_relay()
self._prekey_lookup_aliases[handle] = self._make_prekey_lookup_alias_record(
aid,
expires_at=expires_at,
max_uses=max_uses,
use_count=use_count,
last_used_at=last_used_at,
)
self._save()
def unregister_prekey_lookup_alias(self, alias: str) -> bool:
"""Remove an invite-scoped lookup alias from the local relay."""
handle = str(alias or "").strip()
if not handle:
return False
removed = False
with self._lock:
self._refresh_from_shared_relay()
if handle in self._prekey_lookup_aliases:
del self._prekey_lookup_aliases[handle]
removed = True
if removed:
self._save()
return removed
def consume_one_time_prekey(self, agent_id: str) -> dict[str, Any] | None:
"""Atomically claim the next published one-time prekey for a peer bundle."""
claimed: dict[str, Any] | None = None
with self._lock:
self._refresh_from_shared_relay()
stored = self._prekey_bundles.get(agent_id)
if not stored:
return None
bundle = dict(stored.get("bundle") or {})
otks = list(bundle.get("one_time_prekeys") or [])
if not otks:
return dict(stored)
claimed = dict(otks.pop(0) or {})
bundle["one_time_prekeys"] = otks
bundle["one_time_prekey_count"] = len(otks)
stored = dict(stored)
stored["bundle"] = bundle
stored["updated_at"] = int(time.time())
self._prekey_bundles[agent_id] = stored
self._save()
result = dict(stored)
result["claimed_one_time_prekey"] = claimed
return result
def _prune_witnesses(self, target_id: str) -> None:
cutoff = time.time() - self._witness_ttl_seconds()
self._witnesses[target_id] = [
w for w in self._witnesses.get(target_id, []) if float(w.get("timestamp", 0)) >= cutoff
]
if not self._witnesses[target_id]:
del self._witnesses[target_id]
def record_witness(
self,
witness_id: str,
target_id: str,
dh_pub_key: str,
timestamp: int,
) -> tuple[bool, str]:
if not witness_id or not target_id or not dh_pub_key:
return False, "Missing witness_id, target_id, or dh_pub_key"
if witness_id == target_id:
return False, "Cannot witness yourself"
with self._lock:
self._refresh_from_shared_relay()
self._prune_witnesses(target_id)
entries = self._witnesses.get(target_id, [])
for entry in entries:
if entry.get("witness_id") == witness_id and entry.get("dh_pub_key") == dh_pub_key:
return False, "Duplicate witness"
entries.append(
{
"witness_id": witness_id,
"dh_pub_key": dh_pub_key,
"timestamp": int(timestamp),
}
)
self._witnesses[target_id] = entries[-50:]
self._save()
return True, "ok"
def get_witnesses(self, target_id: str, dh_pub_key: str | None = None, limit: int = 5) -> list[dict]:
with self._lock:
self._refresh_from_shared_relay()
self._prune_witnesses(target_id)
entries = list(self._witnesses.get(target_id, []))
if dh_pub_key:
entries = [e for e in entries if e.get("dh_pub_key") == dh_pub_key]
entries = sorted(entries, key=lambda e: e.get("timestamp", 0), reverse=True)
return entries[: max(1, limit)]
def _epoch_bucket(self, ts: float | None = None) -> int:
now = ts if ts is not None else time.time()
return int(now // EPOCH_SECONDS)
def _mailbox_limit_for_class(self, delivery_class: str) -> int:
if delivery_class == "request":
return self._request_mailbox_limit()
if delivery_class == "shared":
return self._shared_mailbox_limit()
return self._self_mailbox_limit()
def _cleanup_expired(self) -> bool:
now = time.time()
changed = False
for mailbox_id in list(self._mailboxes):
fresh = [m for m in self._mailboxes[mailbox_id] if now - m.timestamp < TTL_SECONDS]
if len(fresh) != len(self._mailboxes[mailbox_id]):
changed = True
self._mailboxes[mailbox_id] = fresh
if not self._mailboxes[mailbox_id]:
del self._mailboxes[mailbox_id]
changed = True
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
return changed
def _total_nonce_count(self) -> int:
return sum(len(c) for c in self._nonce_caches.values())
def _trim_global_nonce_budget(self, *, preferred_agent_id: str = "") -> int:
trimmed = 0
max_entries = self._nonce_cache_max_entries()
preferred_agent_id = str(preferred_agent_id or "").strip()
while self._total_nonce_count() >= max_entries:
oldest_choice: tuple[str, str, float] | None = None
for aid, cache in self._nonce_caches.items():
if not cache:
continue
if preferred_agent_id and aid == preferred_agent_id and len(self._nonce_caches) > 1:
continue
nonce_value, expiry = next(iter(cache.items()))
if oldest_choice is None or float(expiry) < oldest_choice[2]:
oldest_choice = (aid, nonce_value, float(expiry))
if oldest_choice is None and preferred_agent_id:
for aid, cache in self._nonce_caches.items():
if not cache:
continue
nonce_value, expiry = next(iter(cache.items()))
if oldest_choice is None or float(expiry) < oldest_choice[2]:
oldest_choice = (aid, nonce_value, float(expiry))
if oldest_choice is None:
break
aid, nonce_value, _expiry = oldest_choice
cache = self._nonce_caches.get(aid)
if not cache:
continue
cache.pop(nonce_value, None)
if not cache:
self._nonce_caches.pop(aid, None)
trimmed += 1
if trimmed:
metrics_inc("dm_nonce_cache_trimmed")
return trimmed
def consume_nonce(self, agent_id: str, nonce: str, timestamp: int) -> tuple[bool, str]:
nonce = str(nonce or "").strip()
if not nonce:
return False, "Missing nonce"
agent_id = str(agent_id or "").strip()
now = time.time()
with self._lock:
self._refresh_from_shared_relay()
# Expire stale entries across all agents
for aid in list(self._nonce_caches):
cache = self._nonce_caches[aid]
self._nonce_caches[aid] = OrderedDict(
(k, exp) for k, exp in cache.items() if float(exp) > now
)
if not self._nonce_caches[aid]:
del self._nonce_caches[aid]
agent_cache = self._nonce_caches.get(agent_id)
# Replay check
if agent_cache and nonce in agent_cache:
metrics_inc("dm_nonce_replay")
return False, "nonce replay detected"
# Per-agent capacity check
per_agent_max = self._nonce_per_agent_max()
if agent_cache and len(agent_cache) >= per_agent_max:
metrics_inc("dm_nonce_cache_full")
return False, "nonce cache at capacity"
# Global capacity is a soft memory bound. Trim the oldest nonce
# entries first so one busy agent cannot turn the global budget
# into a cross-agent availability choke point.
if self._total_nonce_count() >= self._nonce_cache_max_entries():
self._trim_global_nonce_budget(preferred_agent_id=agent_id)
expiry = max(now + self._nonce_ttl_seconds(), float(timestamp) + self._nonce_ttl_seconds())
if agent_cache is None:
agent_cache = OrderedDict()
self._nonce_caches[agent_id] = agent_cache
agent_cache[nonce] = expiry
agent_cache.move_to_end(nonce)
self._save()
return True, "ok"
def deposit(
self,
*,
sender_id: str,
raw_sender_id: str = "",
recipient_id: str = "",
ciphertext: str,
msg_id: str = "",
delivery_class: str,
recipient_token: str | None = None,
sender_seal: str = "",
relay_salt: str = "",
sender_token_hash: str = "",
payload_format: str = "dm1",
session_welcome: str = "",
) -> dict[str, Any]:
with self._lock:
self._refresh_from_shared_relay()
authority_sender = str(raw_sender_id or sender_id or "").strip()
sender_block_ref = self._sender_block_ref(
authority_sender,
scope=self._sender_block_scope(
recipient_id=recipient_id,
recipient_token=str(recipient_token or ""),
delivery_class=delivery_class,
),
)
blocked_refs = self._sender_block_refs(
authority_sender,
recipient_id=recipient_id,
recipient_token=str(recipient_token or ""),
delivery_class=delivery_class,
)
if recipient_id and any(ref in self._blocks.get(recipient_id, set()) for ref in blocked_refs):
metrics_inc("dm_drop_blocked")
return {"ok": False, "detail": "Recipient is not accepting your messages"}
if len(ciphertext) > int(self._settings().MESH_DM_MAX_MSG_BYTES):
metrics_inc("dm_drop_oversize")
return {
"ok": False,
"detail": f"Message too large ({len(ciphertext)} > {int(self._settings().MESH_DM_MAX_MSG_BYTES)})",
}
self._cleanup_expired()
if delivery_class == "request":
if not sender_token_hash:
return {"ok": False, "detail": "sender_token required for request delivery"}
mailbox_key = self._mailbox_key("requests", recipient_id)
elif delivery_class == "shared":
if not recipient_token:
metrics_inc("dm_claim_invalid")
return {"ok": False, "detail": "recipient_token required for shared delivery"}
mailbox_key = self._hashed_mailbox_token(recipient_token)
else:
return {"ok": False, "detail": "Unsupported delivery_class"}
if len(self._mailboxes[mailbox_key]) >= self._mailbox_limit_for_class(delivery_class):
metrics_inc("dm_drop_full")
return {"ok": False, "detail": "Recipient mailbox full"}
# Anti-spam: per-(sender, recipient) cap on unacked messages.
# A sender who already has the configured number of messages
# parked in this mailbox can't deposit more until the recipient
# pulls (acks) at least one. The same cap is re-enforced on
# inbound replication in ``accept_replica`` so this rule isn't
# bypassable by patching out the local check on a hostile
# sender's relay — see config.py
# MESH_DM_PENDING_PER_SENDER_LIMIT for the threat model.
per_sender_limit = self._per_sender_pending_limit()
pending = self._per_sender_pending_count(
mailbox_key=mailbox_key,
sender_block_ref=sender_block_ref,
)
if pending >= per_sender_limit:
metrics_inc("dm_drop_per_sender_cap")
return {
"ok": False,
"detail": (
f"Recipient already has {pending} unread message"
f"{'s' if pending != 1 else ''} from you. Wait for "
"them to read your messages before sending more."
),
}
if not msg_id:
msg_id = f"dm_{int(time.time() * 1000)}_{secrets.token_hex(6)}"
elif any(m.msg_id == msg_id for m in self._mailboxes[mailbox_key]):
return {"ok": True, "msg_id": msg_id}
relay_sender_id = (
f"sender_token:{sender_token_hash}"
if sender_token_hash
else sender_id
)
self._mailboxes[mailbox_key].append(
DMMessage(
sender_id=relay_sender_id,
ciphertext=ciphertext,
timestamp=time.time(),
msg_id=msg_id,
delivery_class=delivery_class,
sender_seal=sender_seal,
sender_block_ref=sender_block_ref,
payload_format=str(payload_format or "dm1"),
session_welcome=str(session_welcome or ""),
)
)
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
# Cross-node mailbox replication: push the freshly-stored
# envelope to every authenticated relay peer so the recipient
# can log into ANY node and find their messages. The push is
# async (fire-and-forget thread) so deposit() returns
# immediately — slow Tor peers can't block the sender's UX.
# Each receiving peer re-enforces the per-sender cap on
# acceptance, so hostile relays can't widen the cap.
try:
envelope_for_push = self.envelope_for_replication(
mailbox_key=mailbox_key, msg_id=msg_id,
)
if envelope_for_push:
self._replicate_envelope_to_peers_async(
envelope=envelope_for_push,
)
except Exception:
metrics_inc("dm_replication_push_error")
return {"ok": True, "msg_id": msg_id}
def accept_replica(
self,
*,
envelope: dict[str, Any],
originating_peer_url: str = "",
) -> dict[str, Any]:
"""Receive a DM envelope replicated from a peer relay.
Cross-node mailbox replication entry point. When a sender's local
relay accepts a ``deposit`` and pushes the envelope to
``MESH_RELAY_PEERS`` (so the recipient can log into any peer
node and find their messages), each receiving peer calls
``accept_replica`` to ingest it.
The per-(sender, recipient) cap is re-enforced HERE. That's what
makes the rule a NETWORK rule rather than a client-side honor
system: a hostile sender who patches out the local ``deposit``
check still can't get a 3rd unacked message to spread, because
every honest peer enforces the same cap on inbound replicas.
Result: hostile relays can hold extras locally, but those extras
never reach any node a legitimate recipient is polling from.
Returns the same shape as ``deposit`` so the calling endpoint can
forward the result back to the originating peer.
"""
if not isinstance(envelope, dict):
return {"ok": False, "detail": "envelope must be an object"}
msg_id = str(envelope.get("msg_id", "") or "").strip()
mailbox_key = str(envelope.get("mailbox_key", "") or "").strip()
sender_block_ref = str(envelope.get("sender_block_ref", "") or "").strip()
ciphertext = str(envelope.get("ciphertext", "") or "")
if not msg_id or not mailbox_key or not sender_block_ref or not ciphertext:
return {"ok": False, "detail": "envelope missing required fields"}
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
# Idempotent — if we already hold this exact msg_id, the
# replication round-tripped or a peer pushed the same
# envelope through multiple paths. Accept silently.
if any(m.msg_id == msg_id for m in self._mailboxes.get(mailbox_key, [])):
metrics_inc("dm_replica_duplicate")
return {"ok": True, "msg_id": msg_id, "duplicate": True}
# Same per-class cap as the deposit path — defense in depth
# against a peer that wraps a "deposit" as a "replica" to
# bypass the class limit.
delivery_class = str(envelope.get("delivery_class", "") or "")
if delivery_class in ("request", "shared", "self"):
class_limit = self._mailbox_limit_for_class(delivery_class)
else:
class_limit = self._shared_mailbox_limit()
if len(self._mailboxes.get(mailbox_key, [])) >= class_limit:
metrics_inc("dm_replica_drop_full")
return {"ok": False, "detail": "Recipient mailbox full"}
# THE network rule: per-(sender, recipient) anti-spam cap.
per_sender_limit = self._per_sender_pending_limit()
pending = self._per_sender_pending_count(
mailbox_key=mailbox_key,
sender_block_ref=sender_block_ref,
)
if pending >= per_sender_limit:
metrics_inc("dm_replica_drop_per_sender_cap")
# Returning a structured rejection — the sender's relay
# learns its envelope was rejected by an honest peer and
# can stop trying to push it.
return {
"ok": False,
"detail": (
"Per-sender cap reached on this relay; refusing replica"
),
"cap_violation": True,
"pending": pending,
"limit": per_sender_limit,
}
# Accept the replica into the local mailbox.
self._mailboxes[mailbox_key].append(
DMMessage(
sender_id=str(envelope.get("sender_id", "") or ""),
ciphertext=ciphertext,
timestamp=float(envelope.get("timestamp", time.time()) or time.time()),
msg_id=msg_id,
delivery_class=str(envelope.get("delivery_class", "shared") or "shared"),
sender_seal=str(envelope.get("sender_seal", "") or ""),
relay_salt=str(envelope.get("relay_salt", "") or ""),
sender_block_ref=sender_block_ref,
payload_format=str(envelope.get("payload_format", "dm1") or "dm1"),
session_welcome=str(envelope.get("session_welcome", "") or ""),
)
)
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
metrics_inc("dm_replica_accepted")
return {"ok": True, "msg_id": msg_id}
def _replicate_envelope_to_peers_async(
self,
*,
envelope: dict[str, Any],
) -> None:
"""Push an outbound DM envelope to every authenticated relay peer.
Fire-and-forget: spawned in a background thread so ``deposit``
returns to the caller immediately. Per-peer errors are logged
and swallowed — the sender's UX must not block on slow Tor
peers, and a peer that's down today gets the next message
whenever it comes back. Inbound recipient polling from a healthy
peer keeps the system functional during peer failures.
Each peer is authed with the existing per-peer HMAC pattern
(#256) — same headers and key resolver gate-message replication
uses, so a hostile node that doesn't know any peer's HMAC key
can't impersonate a legitimate relay.
"""
import threading
def _do_push():
try:
import hashlib
import hmac
import requests as _requests
from services.mesh.mesh_crypto import (
normalize_peer_url,
resolve_peer_key_for_url,
)
from services.mesh.mesh_router import (
authenticated_push_peer_urls,
)
peers = authenticated_push_peer_urls()
if not peers:
return
payload = json.dumps(
{"envelope": envelope},
separators=(",", ":"),
ensure_ascii=False,
).encode("utf-8")
timeout = max(
1,
int(getattr(self._settings(), "MESH_RELAY_PUSH_TIMEOUT_S", 10) or 10),
)
for peer_url in peers:
try:
normalized = normalize_peer_url(peer_url)
headers = {"Content-Type": "application/json"}
peer_key = resolve_peer_key_for_url(normalized)
if peer_key:
headers["X-Peer-Url"] = normalized
headers["X-Peer-HMAC"] = hmac.new(
peer_key, payload, hashlib.sha256
).hexdigest()
url = f"{peer_url}/api/mesh/dm/replicate-envelope"
resp = _requests.post(
url, data=payload, timeout=timeout, headers=headers,
)
if resp.status_code == 200:
metrics_inc("dm_replication_push_ok")
else:
# 4xx including the structured cap_violation
# rejection from accept_replica — sender's
# relay learns and stops retrying this msg_id.
metrics_inc("dm_replication_push_rejected")
except Exception:
# Per-peer failure is non-fatal — log to metrics
# but don't break the loop. Other peers and a
# future retry can still propagate the envelope.
metrics_inc("dm_replication_push_error")
continue
except Exception:
# Outer guard — never let replication errors propagate
# back to the sender's deposit() caller.
metrics_inc("dm_replication_push_error")
thread = threading.Thread(
target=_do_push,
name="dm-replicate-push",
daemon=True,
)
thread.start()
def envelope_for_replication(
self,
*,
mailbox_key: str,
msg_id: str,
) -> dict[str, Any] | None:
"""Return the wire-form envelope for a stored message, suitable
for POSTing to a peer relay's replicate-envelope endpoint.
Returns ``None`` if the message isn't in the mailbox (already
acked, expired, never existed). The caller holds the
responsibility for transport security (Tor SOCKS for .onion
peers, per-peer HMAC) and for not leaking the envelope to
clearnet peers when private transport is required.
"""
with self._lock:
for m in self._mailboxes.get(mailbox_key, []):
if m.msg_id == msg_id:
return {
"msg_id": m.msg_id,
"mailbox_key": mailbox_key,
"sender_id": m.sender_id,
"sender_block_ref": m.sender_block_ref,
"sender_seal": m.sender_seal,
"ciphertext": m.ciphertext,
"timestamp": m.timestamp,
"delivery_class": m.delivery_class,
"relay_salt": m.relay_salt,
"payload_format": m.payload_format,
"session_welcome": m.session_welcome,
}
return None
def is_blocked(self, recipient_id: str, sender_id: str) -> bool:
with self._lock:
self._refresh_from_shared_relay()
if not recipient_id:
return False
blocked_refs = self._sender_block_refs(
sender_id,
recipient_id=recipient_id,
delivery_class="request",
)
return any(ref in self._blocks.get(recipient_id, set()) for ref in blocked_refs)
def _collect_from_keys(
self, keys: list[str], *, destructive: bool, limit: int = 0,
) -> tuple[list[dict[str, Any]], bool]:
messages: list[DMMessage] = []
seen: set[str] = set()
popped: dict[str, list[DMMessage]] = {}
for key in keys:
if destructive:
raw = self._mailboxes.pop(key, [])
popped[key] = raw
else:
raw = list(self._mailboxes.get(key, []))
for message in raw:
if message.msg_id in seen:
continue
seen.add(message.msg_id)
messages.append(message)
sorted_messages = sorted(messages, key=lambda item: item.timestamp)
has_more = False
if limit > 0 and len(sorted_messages) > limit:
has_more = True
kept = sorted_messages[:limit]
if destructive:
kept_ids = {m.msg_id for m in kept}
for key, original in popped.items():
remaining = [m for m in original if m.msg_id not in kept_ids]
if remaining:
self._mailboxes.setdefault(key, []).extend(remaining)
sorted_messages = kept
if destructive:
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
result = [
{
"sender_id": message.sender_id,
"ciphertext": message.ciphertext,
"timestamp": message.timestamp,
"msg_id": message.msg_id,
"delivery_class": message.delivery_class,
"sender_seal": message.sender_seal,
"format": message.payload_format,
"session_welcome": message.session_welcome,
}
for message in sorted_messages
]
return result, has_more
def collect_claims(
self, agent_id: str, claims: list[dict[str, Any]], *, limit: int = 0,
) -> tuple[list[dict[str, Any]], bool]:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
return self._collect_from_keys(list(dict.fromkeys(keys)), destructive=True, limit=limit)
def count_claims(self, agent_id: str, claims: list[dict[str, Any]]) -> int:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
return len(messages)
def claim_message_ids(self, agent_id: str, claims: list[dict[str, Any]]) -> set[str]:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
return {
str(message.get("msg_id", "") or "")
for message in messages
if str(message.get("msg_id", "") or "")
}
def collect_legacy(
self, agent_id: str | None = None, agent_token: str | None = None, *, limit: int = 0,
) -> tuple[list[dict[str, Any]], bool]:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
if not agent_token:
return [], False
keys = [self._pepper_token(agent_token), agent_token]
return self._collect_from_keys(list(dict.fromkeys(keys)), destructive=True, limit=limit)
def count_legacy(self, agent_id: str | None = None, agent_token: str | None = None) -> int:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
if not agent_token:
return 0
keys = [self._pepper_token(agent_token), agent_token]
messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
return len(messages)
def block(self, agent_id: str, blocked_id: str) -> None:
with self._lock:
self._refresh_from_shared_relay()
blocked_ref = self._canonical_blocked_id(
blocked_id,
scope=self._sender_block_scope(recipient_id=agent_id, delivery_class="request"),
)
if not blocked_ref:
return
self._blocks[agent_id].add(blocked_ref)
blocked_refs = {blocked_ref}
blocked_label = str(blocked_id or "").strip()
if blocked_label and not blocked_label.startswith("ref:"):
blocked_refs.add(self._legacy_sender_block_ref(blocked_label))
purge_keys = self._legacy_token_candidates(agent_id)
bound_request = self._bound_mailbox_key(agent_id, "requests")
bound_self = self._bound_mailbox_key(agent_id, "self")
if bound_request:
purge_keys.append(bound_request)
if bound_self:
purge_keys.append(bound_self)
purge_keys.extend(
[
self._mailbox_key("self", agent_id),
self._mailbox_key("requests", agent_id),
self._mailbox_key("self", agent_id, self._epoch_bucket() - 1),
self._mailbox_key("requests", agent_id, self._epoch_bucket() - 1),
]
)
for key in set(purge_keys):
if key in self._mailboxes:
self._mailboxes[key] = [
m for m in self._mailboxes[key] if self._message_block_ref(m) not in blocked_refs
]
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
def unblock(self, agent_id: str, blocked_id: str) -> None:
with self._lock:
self._refresh_from_shared_relay()
blocked_ref = self._canonical_blocked_id(
blocked_id,
scope=self._sender_block_scope(recipient_id=agent_id, delivery_class="request"),
)
if not blocked_ref:
return
self._blocks[agent_id].discard(blocked_ref)
blocked_label = str(blocked_id or "").strip()
if blocked_label and not blocked_label.startswith("ref:"):
self._blocks[agent_id].discard(self._legacy_sender_block_ref(blocked_label))
self._save()
dm_relay = DMRelay()