Files
Shadowbroker/backend/services/mesh/mesh_dm_relay.py
T
BigBodyCobain 79b39e8985 DM mailbox: per-(sender, recipient) anti-spam cap + replication primitives
Foundation work for cross-node DM mailbox replication. Adds the network
rule that makes the replication safe to ship next, plus the primitives
the outbound replication PR will call.

The rule
--------
A single sender can have at most N UNACKED messages parked in a single
recipient's mailbox at any one time. Default N=2, tunable via
``MESH_DM_PENDING_PER_SENDER_LIMIT``. Once the recipient pulls (acks) a
message, the sender's quota for that (sender, recipient) pair frees up.

Network rule, not local rule
----------------------------
The cap is enforced TWICE:

  1. ``DMRelay.deposit(...)`` — local check on the sender's own node.
     Refuses to spool the (N+1)th message before it can be replicated.

  2. ``DMRelay.accept_replica(...)`` — replication-acceptance check on
     every receiving peer. Refuses to accept an inbound replica that
     would put the local mailbox over the cap.

The second half is what makes the rule a NETWORK rule. A hostile sender
could patch out the deposit check on their own relay and continue to
spool extras locally — but those extras can never propagate, because
every honest peer enforces the same cap on the way in. A recipient who
polls from honest peers therefore never sees more than N pending from
any one sender, regardless of how many spam attempts the hostile
sender's relay accepted.

New API surface on ``DMRelay``
------------------------------
  _per_sender_pending_limit()       — reads MESH_DM_PENDING_PER_SENDER_LIMIT
  _per_sender_pending_count(...)    — counts unacked from a sender for a mailbox
  accept_replica(envelope=...)      — peer-push receive entry point
  envelope_for_replication(...)     — helper to extract a wire-form envelope

``accept_replica`` is idempotent on duplicate ``msg_id`` (replication
round-trips and multi-path delivery don't double-spool).

``envelope_for_replication`` exposes the exact shape ``accept_replica``
expects, so the follow-up PR (outbound replication wiring) just has to
fetch the envelope and POST it to authenticated peer URLs with the
existing per-peer HMAC pattern from #256.

Why this is PR-1 of two
-----------------------
The full cross-node mailbox replication needs three pieces:

  A. cap enforcement on deposit (in this PR)
  B. cap enforcement on replica acceptance (in this PR)
  C. outbound: push envelope to MESH_RELAY_PEERS after deposit (NEXT PR)

(A) + (B) shipped together close the cap-bypass attack surface BEFORE
(C) introduces the actual cross-node propagation. Shipping them in the
other order would briefly let extras propagate during the window between
"outbound push lands" and "accept_replica cap lands."

Tests
-----
backend/tests/test_dm_relay_per_sender_cap.py — 14 tests:

  TestDepositCap:
    - first 2 deposits succeed (UX baseline)
    - 3rd from same sender rejected with friendly message
    - different senders have independent quotas
    - different recipients have independent quotas
    - ack frees the quota (after recipient pulls, sender can deposit again)
    - cap is env-tunable

  TestAcceptReplicaCap:
    - replica accepted under cap
    - idempotent on duplicate msg_id (no double-spool, no rejection)
    - rejected at cap with structured ``cap_violation`` marker so
      sender's relay can stop retrying
    - per-sender, not per-mailbox: different sender_block_ref passes
      even when another sender at the same mailbox is capped
    - malformed envelope shapes rejected without crash

  TestEnvelopeForReplication:
    - returns the envelope for stored messages
    - returns None for unknown msg_id
    - round-trips through accept_replica end-to-end (proves the wire
      shape matches across the two sides)
2026-05-22 19:18:01 -06:00

1900 lines
81 KiB
Python

"""Metadata-minimized DM relay for request and shared mailboxes.
This relay never decrypts application payloads. In secure mode it keeps
pending ciphertext in memory only and persists just the minimum metadata
needed for continuity: accepted DH bundles, block lists, witness data,
and nonce replay windows.
"""
from __future__ import annotations
import atexit
import hashlib
import json
import logging
import os
import secrets
import threading
import time
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from services.config import get_settings
from services.mesh.mesh_metrics import increment as metrics_inc
from services.mesh.mesh_wormhole_prekey import (
_validate_bundle_record,
transparency_fingerprint_for_bundle_record,
)
from services.mesh.mesh_secure_storage import read_secure_json, write_secure_json
TTL_SECONDS = 3600
EPOCH_SECONDS = 6 * 60 * 60
BACKEND_DIR = Path(__file__).resolve().parents[2]
DEFAULT_DATA_DIR = BACKEND_DIR / "data"
DEFAULT_RELAY_FILE = DEFAULT_DATA_DIR / "dm_relay.json"
DATA_DIR = DEFAULT_DATA_DIR
RELAY_FILE = DEFAULT_RELAY_FILE
logger = logging.getLogger(__name__)
def _stable_json(value: Any) -> str:
return json.dumps(value, sort_keys=True, separators=(",", ":"))
def _get_token_pepper() -> str:
"""Read token pepper lazily so auto-generated values from startup audit take effect."""
pepper = os.environ.get("MESH_DM_TOKEN_PEPPER", "").strip()
if not pepper:
try:
from services.config import get_settings
from services.env_check import _ensure_dm_token_pepper
pepper = _ensure_dm_token_pepper(get_settings())
except Exception:
pepper = os.environ.get("MESH_DM_TOKEN_PEPPER", "").strip()
if not pepper:
raise RuntimeError("MESH_DM_TOKEN_PEPPER is unavailable at runtime")
return pepper
@dataclass
class DMMessage:
sender_id: str
ciphertext: str
timestamp: float
msg_id: str
delivery_class: str
sender_seal: str = ""
relay_salt: str = ""
sender_block_ref: str = ""
payload_format: str = "dm1"
session_welcome: str = ""
class DMRelay:
"""Relay for encrypted request/shared mailboxes."""
def __init__(self) -> None:
self._lock = threading.RLock()
self._mailboxes: dict[str, list[DMMessage]] = defaultdict(list)
self._dh_keys: dict[str, dict[str, Any]] = {}
self._prekey_bundles: dict[str, dict[str, Any]] = {}
self._mailbox_bindings: dict[str, dict[str, Any]] = defaultdict(dict)
self._witnesses: dict[str, list[dict[str, Any]]] = defaultdict(list)
self._blocks: dict[str, set[str]] = defaultdict(set)
self._nonce_caches: dict[str, OrderedDict[str, float]] = {}
"""Per-agent nonce replay caches — keyed by agent_id, values are OrderedDicts of nonce→expiry."""
self._prekey_lookup_aliases: dict[str, dict[str, Any]] = {}
"""Invite-scoped lookup handle → agent_id for prekey bundle fetch without stable identity."""
self._stats: dict[str, int] = {"messages_in_memory": 0}
self._dirty = False
self._save_timer: threading.Timer | None = None
self._last_persist_error = ""
self._SAVE_INTERVAL = 5.0
atexit.register(self._flush)
self._load()
def _settings(self):
return get_settings()
def _persist_spool_enabled(self) -> bool:
return bool(self._settings().MESH_DM_PERSIST_SPOOL)
def _relay_file(self) -> Path:
# Unit tests frequently monkeypatch the module-level relay file so each
# relay instance stays isolated from the shared runtime spool path.
module_override = Path(RELAY_FILE)
if module_override != DEFAULT_RELAY_FILE:
return module_override.expanduser().resolve()
override = str(getattr(self._settings(), "MESH_DM_RELAY_FILE_PATH", "") or "").strip()
if override:
override_path = Path(override).expanduser()
if not override_path.is_absolute():
override_path = BACKEND_DIR / override_path
return override_path.resolve()
return RELAY_FILE
def _relay_data_dir(self) -> Path:
return self._relay_file().parent
def _auto_reload_enabled(self) -> bool:
if Path(RELAY_FILE) != DEFAULT_RELAY_FILE:
return False
return bool(getattr(self._settings(), "MESH_DM_RELAY_AUTO_RELOAD", False))
def _refresh_from_shared_relay(self) -> None:
if self._auto_reload_enabled():
self._reload_snapshot_from_shared_relay()
def _reload_snapshot_from_shared_relay(self) -> None:
relay_file = self._relay_file()
fresh_mailboxes: defaultdict[str, list[DMMessage]] = defaultdict(list)
fresh_dh_keys: dict[str, dict[str, Any]] = {}
fresh_prekey_bundles: dict[str, dict[str, Any]] = {}
fresh_mailbox_bindings: defaultdict[str, dict[str, Any]] = defaultdict(dict)
fresh_witnesses: defaultdict[str, list[dict[str, Any]]] = defaultdict(list)
fresh_blocks: defaultdict[str, set[str]] = defaultdict(set)
fresh_nonce_caches: dict[str, OrderedDict[str, float]] = {}
fresh_prekey_lookup_aliases: dict[str, dict[str, Any]] = {}
fresh_stats: dict[str, int] = {"messages_in_memory": 0}
current_mailboxes = defaultdict(list, {k: list(v) for k, v in self._mailboxes.items()})
current_bindings = defaultdict(
dict,
{
str(agent_id): {
str(kind): dict(entry)
for kind, entry in bindings.items()
if isinstance(entry, dict)
}
for agent_id, bindings in self._mailbox_bindings.items()
if isinstance(bindings, dict)
},
)
if not relay_file.exists():
if not self._persist_spool_enabled():
fresh_mailboxes = current_mailboxes
if not self._metadata_persist_enabled():
fresh_mailbox_bindings = current_bindings
self._mailboxes = fresh_mailboxes
self._dh_keys = fresh_dh_keys
self._prekey_bundles = fresh_prekey_bundles
self._mailbox_bindings = fresh_mailbox_bindings
self._witnesses = fresh_witnesses
self._blocks = fresh_blocks
self._nonce_caches = fresh_nonce_caches
self._prekey_lookup_aliases = fresh_prekey_lookup_aliases
self._stats = fresh_stats
return
try:
data = read_secure_json(relay_file, lambda: {})
except Exception:
return
if self._persist_spool_enabled():
mailboxes = data.get("mailboxes", {})
if isinstance(mailboxes, dict):
for key, items in mailboxes.items():
if not isinstance(items, list):
continue
restored: list[DMMessage] = []
for item in items:
try:
restored.append(
DMMessage(
sender_id=str(item.get("sender_id", "")),
ciphertext=str(item.get("ciphertext", "")),
timestamp=float(item.get("timestamp", 0)),
msg_id=str(item.get("msg_id", "")),
delivery_class=str(item.get("delivery_class", "shared")),
sender_seal=str(item.get("sender_seal", "")),
relay_salt=str(item.get("relay_salt", "") or ""),
sender_block_ref=str(item.get("sender_block_ref", "") or ""),
payload_format=str(item.get("payload_format", item.get("format", "dm1")) or "dm1"),
session_welcome=str(item.get("session_welcome", "") or ""),
)
)
except Exception:
continue
for message in restored:
if not message.sender_block_ref:
message.sender_block_ref = self._message_block_ref(message)
if restored:
fresh_mailboxes[str(key)] = restored
else:
if not self._persist_spool_enabled():
fresh_mailboxes = current_mailboxes
dh_keys = data.get("dh_keys", {})
if isinstance(dh_keys, dict):
fresh_dh_keys = {str(k): dict(v) for k, v in dh_keys.items() if isinstance(v, dict)}
prekey_bundles = data.get("prekey_bundles", {})
if isinstance(prekey_bundles, dict):
fresh_prekey_bundles = {
str(k): dict(v) for k, v in prekey_bundles.items() if isinstance(v, dict)
}
prekey_lookup_aliases = data.get("prekey_lookup_aliases", {})
if isinstance(prekey_lookup_aliases, dict):
for key, value in prekey_lookup_aliases.items():
handle = str(key or "").strip()
record = self._coerce_prekey_lookup_alias_record(value)
if handle and record:
fresh_prekey_lookup_aliases[handle] = record
now = time.time()
mailbox_bindings = data.get("mailbox_bindings", {})
if isinstance(mailbox_bindings, dict) and self._metadata_persist_enabled():
for agent_id, bindings in mailbox_bindings.items():
if not isinstance(bindings, dict):
continue
restored_agent: dict[str, dict[str, Any]] = {}
for kind, entry in bindings.items():
token_hash = ""
last_used = now
if isinstance(entry, dict):
token_hash = str(entry.get("token_hash", "") or "").strip()
last_used = float(entry.get("last_used", now) or now)
else:
token_hash = str(entry or "").strip()
if token_hash:
normalized = self._coerce_mailbox_binding_entry(
{
"token_hash": token_hash,
"bound_at": float(entry.get("bound_at", last_used) or last_used)
if isinstance(entry, dict)
else last_used,
"last_used": last_used,
"expires_at": float(entry.get("expires_at", 0) or 0)
if isinstance(entry, dict)
else 0,
},
now=now,
)
if normalized:
restored_agent[str(kind)] = normalized
if restored_agent:
fresh_mailbox_bindings[str(agent_id)] = restored_agent
elif not self._metadata_persist_enabled():
fresh_mailbox_bindings = current_bindings
witnesses = data.get("witnesses", {})
if isinstance(witnesses, dict):
fresh_witnesses = defaultdict(
list,
{str(k): list(v) for k, v in witnesses.items() if isinstance(v, list)},
)
blocks = data.get("blocks", {})
if isinstance(blocks, dict):
for key, values in blocks.items():
if isinstance(values, list):
fresh_blocks[str(key)] = {
self._canonical_blocked_id(str(v))
for v in values
if str(v or "").strip()
}
nonce_caches = data.get("nonce_caches", {})
if isinstance(nonce_caches, dict) and nonce_caches:
for aid, entries in nonce_caches.items():
if not isinstance(entries, dict):
continue
restored = sorted(
((str(k), float(v)) for k, v in entries.items() if float(v) > now),
key=lambda item: item[1],
)
if restored:
fresh_nonce_caches[str(aid)] = OrderedDict(restored)
else:
nonce_cache = data.get("nonce_cache", {})
if isinstance(nonce_cache, dict):
for composite_key, expiry in nonce_cache.items():
if float(expiry) <= now:
continue
parts = str(composite_key).split(":", 1)
if len(parts) == 2:
aid, nonce_val = parts
if aid not in fresh_nonce_caches:
fresh_nonce_caches[aid] = OrderedDict()
fresh_nonce_caches[aid][nonce_val] = float(expiry)
stats = data.get("stats", {})
if isinstance(stats, dict):
fresh_stats = {str(k): int(v) for k, v in stats.items() if isinstance(v, (int, float))}
self._mailboxes = fresh_mailboxes
self._dh_keys = fresh_dh_keys
self._prekey_bundles = fresh_prekey_bundles
self._mailbox_bindings = fresh_mailbox_bindings
self._witnesses = fresh_witnesses
self._blocks = fresh_blocks
self._nonce_caches = fresh_nonce_caches
self._prekey_lookup_aliases = fresh_prekey_lookup_aliases
self._stats = fresh_stats
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
if self._prune_stale_metadata():
self._dirty = True
def _request_mailbox_limit(self) -> int:
return max(1, int(self._settings().MESH_DM_REQUEST_MAILBOX_LIMIT))
def _shared_mailbox_limit(self) -> int:
return max(1, int(self._settings().MESH_DM_SHARED_MAILBOX_LIMIT))
def _self_mailbox_limit(self) -> int:
return max(1, int(self._settings().MESH_DM_SELF_MAILBOX_LIMIT))
def _per_sender_pending_limit(self) -> int:
"""Anti-spam cap on UNACKED messages a single sender can have parked
in a single recipient mailbox at any one time. See ``config.py``
``MESH_DM_PENDING_PER_SENDER_LIMIT`` for the threat model — this
rule is enforced both at ``deposit`` (local) and at
``accept_replica`` (peer push acceptance), making it a network
rule rather than a client-side honor system."""
try:
limit = int(getattr(self._settings(), "MESH_DM_PENDING_PER_SENDER_LIMIT", 2) or 2)
except (TypeError, ValueError):
limit = 2
return max(1, limit)
def _per_sender_pending_count(
self,
*,
mailbox_key: str,
sender_block_ref: str,
) -> int:
"""Count UNACKED messages from ``sender_block_ref`` currently parked
in ``mailbox_key``. Caller already holds ``self._lock``.
Messages that have been claimed/acked are removed from the mailbox
list (see ``claim_message_ids``), so anything still here is by
definition unacked. We count by exact ``sender_block_ref`` match
— that's the per-pair sender identity used for blocking too, so
the cap is naturally per-(sender, recipient).
"""
if not mailbox_key or not sender_block_ref:
return 0
messages = self._mailboxes.get(mailbox_key, [])
return sum(1 for m in messages if m.sender_block_ref == sender_block_ref)
def _nonce_ttl_seconds(self) -> int:
return max(30, int(self._settings().MESH_DM_NONCE_TTL_S))
def _nonce_cache_max_entries(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_NONCE_CACHE_MAX", 4096)))
def _nonce_per_agent_max(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_NONCE_PER_AGENT_MAX", 256)))
def _dm_key_ttl_seconds(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_KEY_TTL_DAYS", 30) or 30)) * 86400
def _prekey_lookup_alias_ttl_seconds(self) -> int:
return max(
1,
int(getattr(self._settings(), "MESH_DM_PREKEY_LOOKUP_ALIAS_TTL_DAYS", 14) or 14),
) * 86400
def _witness_ttl_seconds(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_WITNESS_TTL_DAYS", 14) or 14)) * 86400
def _mailbox_binding_ttl_seconds(self) -> int:
return max(1, int(getattr(self._settings(), "MESH_DM_BINDING_TTL_DAYS", 3) or 3)) * 86400
def _mailbox_binding_idle_ttl_seconds(self) -> int:
return min(self._mailbox_binding_ttl_seconds(), 12 * 60 * 60)
def _mailbox_binding_refresh_after_seconds(self) -> int:
return max(15 * 60, min(self._mailbox_binding_ttl_seconds(), 12 * 60 * 60))
def _mailbox_binding_expires_at(self, entry: dict[str, Any]) -> float:
bound_at = float(entry.get("bound_at", 0) or 0)
last_used = float(entry.get("last_used", bound_at) or bound_at)
if bound_at <= 0:
return 0.0
return min(
bound_at + self._mailbox_binding_ttl_seconds(),
last_used + self._mailbox_binding_idle_ttl_seconds(),
)
def _coerce_mailbox_binding_entry(self, entry: Any, *, now: float | None = None) -> dict[str, Any]:
current = time.time() if now is None else float(now)
token_hash = ""
bound_at = current
last_used = current
if isinstance(entry, dict):
token_hash = str(entry.get("token_hash", "") or "").strip()
bound_at = float(entry.get("bound_at", entry.get("last_used", current)) or current)
last_used = float(entry.get("last_used", bound_at) or bound_at)
else:
token_hash = str(entry or "").strip()
if not token_hash:
return {}
normalized = {
"token_hash": token_hash,
"bound_at": bound_at,
"last_used": last_used,
}
normalized["expires_at"] = self._mailbox_binding_expires_at(normalized)
return normalized
def _alias_updated_at_for_agent(self, agent_id: str) -> float:
stored = self._prekey_bundles.get(str(agent_id or "").strip(), {})
if isinstance(stored, dict):
return float(stored.get("updated_at", stored.get("timestamp", time.time())) or time.time())
return float(time.time())
def _make_prekey_lookup_alias_record(
self,
agent_id: str,
*,
updated_at: float | None = None,
expires_at: int = 0,
max_uses: int = 0,
use_count: int = 0,
last_used_at: float = 0,
) -> dict[str, Any]:
aid = str(agent_id or "").strip()
if not aid:
return {}
if updated_at is None:
updated_at = self._alias_updated_at_for_agent(aid)
return {
"agent_id": aid,
"updated_at": float(updated_at or self._alias_updated_at_for_agent(aid)),
"expires_at": max(0, int(expires_at or 0)),
"max_uses": max(0, int(max_uses or 0)),
"use_count": max(0, int(use_count or 0)),
"last_used_at": float(last_used_at or 0),
}
def _coerce_prekey_lookup_alias_record(self, value: Any) -> dict[str, Any]:
if isinstance(value, dict):
aid = str(value.get("agent_id", "") or "").strip()
if not aid:
return {}
updated_at = float(
value.get("updated_at", value.get("last_used", self._alias_updated_at_for_agent(aid)))
or self._alias_updated_at_for_agent(aid)
)
return self._make_prekey_lookup_alias_record(
aid,
updated_at=updated_at,
expires_at=int(value.get("expires_at", 0) or 0),
max_uses=int(value.get("max_uses", 0) or 0),
use_count=int(value.get("use_count", value.get("uses", 0)) or 0),
last_used_at=float(value.get("last_used_at", value.get("last_used", 0)) or 0),
)
aid = str(value or "").strip()
if not aid:
return {}
return self._make_prekey_lookup_alias_record(aid)
def _resolve_prekey_lookup_alias(self, lookup_token: str) -> str:
handle = str(lookup_token or "").strip()
if not handle:
return ""
record = self._coerce_prekey_lookup_alias_record(self._prekey_lookup_aliases.get(handle, {}))
if not record:
return ""
now = time.time()
expires_at = int(record.get("expires_at", 0) or 0)
max_uses = int(record.get("max_uses", 0) or 0)
use_count = int(record.get("use_count", 0) or 0)
if (expires_at > 0 and now > expires_at) or (max_uses > 0 and use_count >= max_uses):
self._prekey_lookup_aliases.pop(handle, None)
self._save()
return ""
updated = self._make_prekey_lookup_alias_record(
str(record.get("agent_id", "") or "").strip(),
updated_at=float(record.get("updated_at", self._alias_updated_at_for_agent(str(record.get("agent_id", "") or "").strip())) or now),
expires_at=expires_at,
max_uses=max_uses,
use_count=use_count + 1,
last_used_at=now,
)
self._prekey_lookup_aliases[handle] = updated
self._save()
try:
from services.mesh.mesh_wormhole_identity import record_prekey_lookup_handle_use
record_prekey_lookup_handle_use(handle, now=int(now))
except Exception:
pass
return str(updated.get("agent_id", "") or "").strip()
def _pepper_token(self, token: str) -> str:
material = token
pepper = _get_token_pepper()
if pepper:
material = f"{pepper}|{token}"
return hashlib.sha256(material.encode("utf-8")).hexdigest()
def _legacy_sender_block_ref(self, sender_id: str) -> str:
sender = str(sender_id or "").strip()
if not sender:
return ""
return "ref:" + self._pepper_token(f"block|{sender}")
def _sender_block_scope(
self,
*,
recipient_id: str = "",
recipient_token: str = "",
delivery_class: str = "",
) -> str:
recipient = str(recipient_id or "").strip()
if recipient:
return f"recipient|{recipient}"
token = str(recipient_token or "").strip()
if token and str(delivery_class or "").strip().lower() == "shared":
return f"shared|{self._hashed_mailbox_token(token)}"
return ""
def _sender_block_ref(self, sender_id: str, *, scope: str = "") -> str:
sender = str(sender_id or "").strip()
if not sender:
return ""
material = f"block|{scope}|{sender}" if scope else f"block|{sender}"
return "ref:" + self._pepper_token(material)
def _sender_block_refs(
self,
sender_id: str,
*,
recipient_id: str = "",
recipient_token: str = "",
delivery_class: str = "",
) -> set[str]:
refs: set[str] = set()
legacy = self._legacy_sender_block_ref(sender_id)
if legacy:
refs.add(legacy)
scoped = self._sender_block_ref(
sender_id,
scope=self._sender_block_scope(
recipient_id=recipient_id,
recipient_token=recipient_token,
delivery_class=delivery_class,
),
)
if scoped:
refs.add(scoped)
return refs
def _canonical_blocked_id(self, blocked_id: str, *, scope: str = "") -> str:
blocked = str(blocked_id or "").strip()
if not blocked:
return ""
if blocked.startswith("ref:"):
return blocked
return self._sender_block_ref(blocked, scope=scope)
def _message_block_ref(self, message: DMMessage) -> str:
block_ref = str(getattr(message, "sender_block_ref", "") or "").strip()
if block_ref:
return block_ref
sender_id = str(message.sender_id or "").strip()
if not sender_id or sender_id.startswith("sealed:") or sender_id.startswith("sender_token:"):
return ""
return self._legacy_sender_block_ref(sender_id)
def _mailbox_key(self, mailbox_type: str, mailbox_value: str, epoch: int | None = None) -> str:
if mailbox_type in {"self", "requests"}:
bucket = self._epoch_bucket() if epoch is None else int(epoch)
identifier = f"{mailbox_type}|{bucket}|{mailbox_value}"
else:
identifier = f"{mailbox_type}|{mailbox_value}"
return self._pepper_token(identifier)
def _hashed_mailbox_token(self, token: str) -> str:
return hashlib.sha256(str(token or "").encode("utf-8")).hexdigest()
def _remember_mailbox_binding(self, agent_id: str, mailbox_type: str, token: str) -> str:
if self._prune_stale_mailbox_bindings():
self._save()
now = time.time()
agent_key = str(agent_id or "").strip()
mailbox_key = str(mailbox_type or "").strip().lower()
token_hash = self._hashed_mailbox_token(token)
current = self._coerce_mailbox_binding_entry(
self._mailbox_bindings.get(agent_key, {}).get(mailbox_key, {}),
now=now,
)
refreshed = {
"token_hash": token_hash,
"bound_at": now,
"last_used": now,
}
if current and str(current.get("token_hash", "") or "") == token_hash:
refreshed["bound_at"] = float(current.get("bound_at", now) or now)
if (now - refreshed["bound_at"]) >= self._mailbox_binding_refresh_after_seconds():
refreshed["bound_at"] = now
refreshed["expires_at"] = self._mailbox_binding_expires_at(refreshed)
self._mailbox_bindings[agent_key][mailbox_key] = refreshed
self._save()
return token_hash
def _bound_mailbox_key(self, agent_id: str, mailbox_type: str) -> str:
if self._prune_stale_mailbox_bindings():
self._save()
agent_key = str(agent_id or "").strip()
mailbox_key = str(mailbox_type or "").strip().lower()
entry = self._mailbox_bindings.get(agent_key, {}).get(
mailbox_key,
{},
)
normalized = self._coerce_mailbox_binding_entry(entry)
if normalized and normalized != entry:
self._mailbox_bindings[agent_key][mailbox_key] = normalized
self._save()
return str(normalized.get("token_hash", "") or "")
def _mailbox_keys_for_claim(self, agent_id: str, claim: dict[str, Any]) -> list[str]:
claim_type = str(claim.get("type", "")).strip().lower()
if claim_type == "shared":
token = str(claim.get("token", "")).strip()
if not token:
metrics_inc("dm_claim_invalid")
return []
return [self._hashed_mailbox_token(token)]
if claim_type == "requests":
token = str(claim.get("token", "")).strip()
if token:
previous_bound = self._bound_mailbox_key(agent_id, "requests")
bound_key = self._remember_mailbox_binding(agent_id, "requests", token)
epoch = self._epoch_bucket()
return [
key
for key in [
previous_bound,
bound_key,
self._mailbox_key("requests", agent_id, epoch),
self._mailbox_key("requests", agent_id, epoch - 1),
]
if key
]
metrics_inc("dm_claim_invalid")
return []
if claim_type == "self":
token = str(claim.get("token", "")).strip()
if token:
previous_bound = self._bound_mailbox_key(agent_id, "self")
bound_key = self._remember_mailbox_binding(agent_id, "self", token)
epoch = self._epoch_bucket()
return [
key
for key in [
previous_bound,
bound_key,
self._mailbox_key("self", agent_id, epoch),
self._mailbox_key("self", agent_id, epoch - 1),
]
if key
]
metrics_inc("dm_claim_invalid")
return []
metrics_inc("dm_claim_invalid")
return []
def mailbox_key_for_delivery(
self,
*,
recipient_id: str,
delivery_class: str,
recipient_token: str | None = None,
) -> str:
with self._lock:
self._refresh_from_shared_relay()
delivery_class = str(delivery_class or "").strip().lower()
if delivery_class == "request":
bound_key = self._bound_mailbox_key(recipient_id, "requests")
if bound_key:
return bound_key
return self._mailbox_key("requests", str(recipient_id or "").strip())
if delivery_class == "shared":
token = str(recipient_token or "").strip()
if not token:
raise ValueError("recipient_token required for shared delivery")
return self._hashed_mailbox_token(token)
raise ValueError("Unsupported delivery_class")
def claim_mailbox_keys(self, agent_id: str, claims: list[dict[str, Any]]) -> list[str]:
with self._lock:
self._refresh_from_shared_relay()
if self._prune_stale_mailbox_bindings():
self._save()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
return list(dict.fromkeys(keys))
def _legacy_mailbox_token(self, agent_id: str, epoch: int) -> str:
raw = f"sb_dm|{epoch}|{agent_id}".encode("utf-8")
return hashlib.sha256(raw).hexdigest()
def _legacy_token_candidates(self, agent_id: str) -> list[str]:
epoch = self._epoch_bucket()
raw = [self._legacy_mailbox_token(agent_id, epoch), self._legacy_mailbox_token(agent_id, epoch - 1)]
peppered = [self._pepper_token(token) for token in raw]
return list(dict.fromkeys(peppered + raw))
def _save(self) -> None:
"""Mark dirty and schedule a coalesced disk write."""
self._dirty = True
relay_file = self._relay_file()
if self._auto_reload_enabled() or not relay_file.exists() or self._persist_failures_are_fatal():
self._flush()
return
with self._lock:
if self._save_timer is None or not self._save_timer.is_alive():
self._save_timer = threading.Timer(self._SAVE_INTERVAL, self._flush)
self._save_timer.daemon = True
self._save_timer.start()
def _persist_failures_are_fatal(self) -> bool:
return bool(os.environ.get("PYTEST_CURRENT_TEST", "").strip())
def _record_persist_failure(self, operation: str, exc: Exception) -> None:
self._last_persist_error = f"{operation}:{type(exc).__name__}:{exc}"
metrics_inc("dm_relay_persist_failure")
logger.exception("dm relay %s failed for %s", operation, self._relay_file())
def _prune_stale_metadata(self) -> bool:
"""Remove expired relay metadata that should not outlive its retention window."""
now = time.time()
key_ttl = self._dm_key_ttl_seconds()
alias_ttl = self._prekey_lookup_alias_ttl_seconds()
witness_ttl = self._witness_ttl_seconds()
changed = False
stale_keys = [
aid for aid, entry in self._dh_keys.items()
if (now - float(entry.get("timestamp", 0) or 0)) > key_ttl
]
for aid in stale_keys:
del self._dh_keys[aid]
changed = True
stale_bundles = [
aid for aid, entry in self._prekey_bundles.items()
if (now - float(entry.get("updated_at", entry.get("timestamp", 0)) or 0)) > key_ttl
]
for aid in stale_bundles:
del self._prekey_bundles[aid]
changed = True
stale_aliases: list[str] = []
for alias, value in list(self._prekey_lookup_aliases.items()):
record = self._coerce_prekey_lookup_alias_record(value)
if not record:
stale_aliases.append(alias)
continue
if self._prekey_lookup_aliases.get(alias) != record:
self._prekey_lookup_aliases[alias] = record
changed = True
target = str(record.get("agent_id", "") or "").strip()
updated_at = float(record.get("updated_at", self._alias_updated_at_for_agent(target)) or 0)
expires_at = int(record.get("expires_at", 0) or 0)
max_uses = int(record.get("max_uses", 0) or 0)
use_count = int(record.get("use_count", 0) or 0)
if (
not target
or target not in self._prekey_bundles
or (now - updated_at) > alias_ttl
or (expires_at > 0 and now > float(expires_at))
or (max_uses > 0 and use_count >= max_uses)
):
stale_aliases.append(alias)
for alias in stale_aliases:
del self._prekey_lookup_aliases[alias]
changed = True
for target_id in list(self._witnesses):
fresh = [
witness
for witness in self._witnesses.get(target_id, [])
if (now - float(witness.get("timestamp", 0) or 0)) <= witness_ttl
]
if len(fresh) != len(self._witnesses.get(target_id, [])):
changed = True
if fresh:
self._witnesses[target_id] = fresh
else:
del self._witnesses[target_id]
if self._prune_stale_mailbox_bindings(now=now):
changed = True
return changed
def _prune_stale_mailbox_bindings(self, *, now: float | None = None) -> bool:
current = time.time() if now is None else now
changed = False
stale_agents: list[str] = []
for agent_id, kinds in self._mailbox_bindings.items():
normalized_updates: dict[str, dict[str, Any]] = {}
expired_kinds = [
k
for k, v in kinds.items()
if not self._coerce_mailbox_binding_entry(v, now=current)
or current > self._mailbox_binding_expires_at(
self._coerce_mailbox_binding_entry(v, now=current)
)
]
for kind, entry in list(kinds.items()):
normalized = self._coerce_mailbox_binding_entry(entry, now=current)
if normalized and normalized != entry:
normalized_updates[kind] = normalized
for kind, normalized in normalized_updates.items():
kinds[kind] = normalized
changed = True
for k in expired_kinds:
del kinds[k]
changed = True
if not kinds:
stale_agents.append(agent_id)
for agent_id in stale_agents:
del self._mailbox_bindings[agent_id]
changed = True
return changed
def _metadata_persist_enabled(self) -> bool:
settings = self._settings()
return bool(getattr(settings, "MESH_DM_METADATA_PERSIST", False)) and bool(
getattr(settings, "MESH_DM_METADATA_PERSIST_ACKNOWLEDGE", False)
)
def _flush(self) -> None:
"""Actually write to disk (called by timer or atexit)."""
if not self._dirty:
return
try:
self._prune_stale_metadata()
relay_file = self._relay_file()
self._relay_data_dir().mkdir(parents=True, exist_ok=True)
payload: dict[str, Any] = {
"saved_at": int(time.time()),
"dh_keys": self._dh_keys,
"prekey_bundles": self._prekey_bundles,
"prekey_lookup_aliases": self._prekey_lookup_aliases,
"witnesses": self._witnesses,
"blocks": {k: sorted(v) for k, v in self._blocks.items()},
"nonce_caches": {aid: dict(c) for aid, c in self._nonce_caches.items()},
"stats": self._stats,
}
if self._metadata_persist_enabled():
payload["mailbox_bindings"] = {
agent_id: {
mailbox_type: {
"token_hash": str(entry.get("token_hash", "") or "").strip(),
"bound_at": float(entry.get("bound_at", 0) or 0),
"last_used": float(entry.get("last_used", 0) or 0),
"expires_at": float(entry.get("expires_at", 0) or 0),
}
for mailbox_type, entry in bindings.items()
if isinstance(entry, dict) and str(entry.get("token_hash", "") or "").strip()
}
for agent_id, bindings in self._mailbox_bindings.items()
if isinstance(bindings, dict)
}
if self._persist_spool_enabled():
payload["mailboxes"] = {
key: [m.__dict__ for m in msgs] for key, msgs in self._mailboxes.items()
}
write_secure_json(relay_file, payload)
self._dirty = False
self._last_persist_error = ""
except Exception as exc:
self._record_persist_failure("flush", exc)
if self._persist_failures_are_fatal():
raise
def _load(self) -> None:
relay_file = self._relay_file()
if not relay_file.exists():
return
try:
data = read_secure_json(relay_file, lambda: {})
except Exception:
return
if self._persist_spool_enabled():
mailboxes = data.get("mailboxes", {})
if isinstance(mailboxes, dict):
for key, items in mailboxes.items():
if not isinstance(items, list):
continue
restored: list[DMMessage] = []
for item in items:
try:
restored.append(
DMMessage(
sender_id=str(item.get("sender_id", "")),
ciphertext=str(item.get("ciphertext", "")),
timestamp=float(item.get("timestamp", 0)),
msg_id=str(item.get("msg_id", "")),
delivery_class=str(item.get("delivery_class", "shared")),
sender_seal=str(item.get("sender_seal", "")),
relay_salt=str(item.get("relay_salt", "") or ""),
sender_block_ref=str(item.get("sender_block_ref", "") or ""),
payload_format=str(item.get("payload_format", item.get("format", "dm1")) or "dm1"),
session_welcome=str(item.get("session_welcome", "") or ""),
)
)
except Exception:
continue
for message in restored:
if not message.sender_block_ref:
message.sender_block_ref = self._message_block_ref(message)
if restored:
self._mailboxes[key] = restored
dh_keys = data.get("dh_keys", {})
if isinstance(dh_keys, dict):
self._dh_keys = {str(k): dict(v) for k, v in dh_keys.items() if isinstance(v, dict)}
prekey_bundles = data.get("prekey_bundles", {})
if isinstance(prekey_bundles, dict):
self._prekey_bundles = {
str(k): dict(v) for k, v in prekey_bundles.items() if isinstance(v, dict)
}
prekey_lookup_aliases = data.get("prekey_lookup_aliases", {})
if isinstance(prekey_lookup_aliases, dict):
restored_aliases: dict[str, dict[str, Any]] = {}
alias_records_migrated = False
for key, value in prekey_lookup_aliases.items():
handle = str(key or "").strip()
record = self._coerce_prekey_lookup_alias_record(value)
if not handle or not record:
continue
restored_aliases[handle] = record
if value != record:
alias_records_migrated = True
self._prekey_lookup_aliases = restored_aliases
if alias_records_migrated:
self._dirty = True
now = time.time()
mailbox_bindings = data.get("mailbox_bindings", {})
if isinstance(mailbox_bindings, dict):
if self._metadata_persist_enabled():
restored_bindings: dict[str, dict[str, dict[str, Any]]] = {}
for agent_id, bindings in mailbox_bindings.items():
if not isinstance(bindings, dict):
continue
restored_agent: dict[str, dict[str, Any]] = {}
for kind, entry in bindings.items():
token_hash = ""
last_used = now
if isinstance(entry, dict):
token_hash = str(entry.get("token_hash", "") or "").strip()
last_used = float(entry.get("last_used", now) or now)
else:
token_hash = str(entry or "").strip()
if not token_hash:
continue
normalized = self._coerce_mailbox_binding_entry(
{
"token_hash": token_hash,
"bound_at": float(entry.get("bound_at", last_used) or last_used)
if isinstance(entry, dict)
else last_used,
"last_used": last_used,
"expires_at": float(entry.get("expires_at", 0) or 0)
if isinstance(entry, dict)
else 0,
},
now=now,
)
if normalized:
restored_agent[str(kind)] = normalized
if restored_agent:
restored_bindings[str(agent_id)] = restored_agent
self._mailbox_bindings = defaultdict(dict, restored_bindings)
elif mailbox_bindings:
# Old relay files may still contain persisted mailbox bindings.
# When metadata persistence is disabled we intentionally do not
# restore them, and mark dirty so the next flush rewrites the
# relay state without that graph metadata.
self._dirty = True
witnesses = data.get("witnesses", {})
if isinstance(witnesses, dict):
self._witnesses = defaultdict(
list,
{
str(k): list(v)
for k, v in witnesses.items()
if isinstance(v, list)
},
)
blocks = data.get("blocks", {})
if isinstance(blocks, dict):
for key, values in blocks.items():
if isinstance(values, list):
self._blocks[str(key)] = {
self._canonical_blocked_id(str(v))
for v in values
if str(v or "").strip()
}
nonce_caches = data.get("nonce_caches", {})
if isinstance(nonce_caches, dict) and nonce_caches:
for aid, entries in nonce_caches.items():
if not isinstance(entries, dict):
continue
restored = sorted(
((str(k), float(v)) for k, v in entries.items() if float(v) > now),
key=lambda item: item[1],
)
if restored:
self._nonce_caches[str(aid)] = OrderedDict(restored)
else:
# Backward compatibility: migrate flat nonce_cache → per-agent
nonce_cache = data.get("nonce_cache", {})
if isinstance(nonce_cache, dict):
for composite_key, expiry in nonce_cache.items():
if float(expiry) <= now:
continue
parts = str(composite_key).split(":", 1)
if len(parts) == 2:
aid, nonce_val = parts
if aid not in self._nonce_caches:
self._nonce_caches[aid] = OrderedDict()
self._nonce_caches[aid][nonce_val] = float(expiry)
stats = data.get("stats", {})
if isinstance(stats, dict):
self._stats = {str(k): int(v) for k, v in stats.items() if isinstance(v, (int, float))}
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
if self._prune_stale_metadata():
self._dirty = True
def _bundle_fingerprint(
self,
*,
dh_pub_key: str,
dh_algo: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
) -> str:
material = "|".join(
[
dh_pub_key,
dh_algo,
public_key,
public_key_algo,
protocol_version,
]
)
return hashlib.sha256(material.encode("utf-8")).hexdigest()
def _advance_prekey_transparency(
self,
*,
agent_id: str,
bundle: dict[str, Any],
signature: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
sequence: int,
existing: dict[str, Any] | None,
) -> dict[str, Any]:
previous_head = str((existing or {}).get("prekey_transparency_head", "") or "").strip().lower()
previous_size = int((existing or {}).get("prekey_transparency_size", 0) or 0)
publication_fingerprint = transparency_fingerprint_for_bundle_record(
{
"agent_id": agent_id,
"bundle": bundle,
"signature": signature,
"public_key": public_key,
"public_key_algo": public_key_algo,
"protocol_version": protocol_version,
"sequence": int(sequence),
}
)
next_size = previous_size + 1
head_payload = {
"agent_id": agent_id,
"sequence": int(sequence),
"signed_at": int(bundle.get("signed_at", 0) or 0),
"publication_fingerprint": publication_fingerprint,
"previous_head": previous_head,
"index": next_size,
}
head = hashlib.sha256(_stable_json(head_payload).encode("utf-8")).hexdigest()
history = list((existing or {}).get("prekey_transparency_log") or [])
history.append(
{
"index": next_size,
"sequence": int(sequence),
"signed_at": int(bundle.get("signed_at", 0) or 0),
"publication_fingerprint": publication_fingerprint,
"previous_head": previous_head,
"head": head,
"observed_at": int(time.time()),
}
)
return {
"prekey_transparency_head": head,
"prekey_transparency_size": next_size,
"prekey_transparency_fingerprint": publication_fingerprint,
"prekey_transparency_log": history[-16:],
}
def register_dh_key(
self,
agent_id: str,
dh_pub_key: str,
dh_algo: str,
timestamp: int,
signature: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
sequence: int,
) -> tuple[bool, str, dict[str, Any] | None]:
"""Register/update an agent's DH public key bundle with replay protection."""
fingerprint = self._bundle_fingerprint(
dh_pub_key=dh_pub_key,
dh_algo=dh_algo,
public_key=public_key,
public_key_algo=public_key_algo,
protocol_version=protocol_version,
)
with self._lock:
self._refresh_from_shared_relay()
existing = self._dh_keys.get(agent_id)
if existing:
existing_seq = int(existing.get("sequence", 0) or 0)
existing_ts = int(existing.get("timestamp", 0) or 0)
if sequence <= existing_seq:
metrics_inc("dm_key_replay")
return False, "DM key replay or rollback rejected", None
if timestamp < existing_ts:
metrics_inc("dm_key_stale")
return False, "DM key timestamp is older than the current bundle", None
self._dh_keys[agent_id] = {
"dh_pub_key": dh_pub_key,
"dh_algo": dh_algo,
"timestamp": timestamp,
"signature": signature,
"public_key": public_key,
"public_key_algo": public_key_algo,
"protocol_version": protocol_version,
"sequence": sequence,
"bundle_fingerprint": fingerprint,
}
self._save()
return True, "ok", {
"accepted_sequence": sequence,
"bundle_fingerprint": fingerprint,
}
def get_dh_key(self, agent_id: str) -> dict[str, Any] | None:
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
return self._dh_keys.get(agent_id)
def get_dh_key_by_lookup(self, lookup_token: str) -> tuple[dict[str, Any] | None, str]:
"""Resolve a prekey lookup alias and return the DH key for the resolved agent."""
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
resolved_id = self._resolve_prekey_lookup_alias(lookup_token)
if not resolved_id:
return None, ""
stored = self._dh_keys.get(resolved_id)
if not stored:
return None, ""
return dict(stored), resolved_id
def register_prekey_bundle(
self,
agent_id: str,
bundle: dict[str, Any],
signature: str,
public_key: str,
public_key_algo: str,
protocol_version: str,
sequence: int,
lookup_aliases: list[Any] | None = None,
) -> tuple[bool, str, dict[str, Any] | None]:
ok, reason = _validate_bundle_record(
{
"bundle": bundle,
"public_key": public_key,
"agent_id": agent_id,
}
)
if not ok:
return False, reason, None
with self._lock:
self._refresh_from_shared_relay()
existing = self._prekey_bundles.get(agent_id)
if existing:
existing_seq = int(existing.get("sequence", 0) or 0)
if sequence <= existing_seq:
return False, "Prekey bundle replay or rollback rejected", None
transparency = self._advance_prekey_transparency(
agent_id=agent_id,
bundle=dict(bundle or {}),
signature=signature,
public_key=public_key,
public_key_algo=public_key_algo,
protocol_version=protocol_version,
sequence=int(sequence),
existing=existing,
)
stored = {
"bundle": dict(bundle or {}),
"signature": signature,
"public_key": public_key,
"public_key_algo": public_key_algo,
"protocol_version": protocol_version,
"sequence": int(sequence),
"updated_at": int(time.time()),
**transparency,
}
self._prekey_bundles[agent_id] = stored
if lookup_aliases:
alias_updated_at = float(stored.get("updated_at", time.time()) or time.time())
for alias in lookup_aliases[:16]:
alias_record = self._coerce_prekey_lookup_alias_record(
{
"agent_id": agent_id,
**dict(alias),
}
if isinstance(alias, dict)
else self._make_prekey_lookup_alias_record(agent_id, updated_at=alias_updated_at)
)
handle = str(alias.get("handle", "") if isinstance(alias, dict) else alias or "").strip()
if handle:
self._prekey_lookup_aliases[handle] = self._make_prekey_lookup_alias_record(
agent_id,
updated_at=alias_updated_at,
expires_at=int(alias_record.get("expires_at", 0) or 0),
max_uses=int(alias_record.get("max_uses", 0) or 0),
use_count=int(alias_record.get("use_count", 0) or 0),
last_used_at=float(alias_record.get("last_used_at", 0) or 0),
)
self._save()
return True, "ok", {"accepted_sequence": int(sequence), **transparency}
def get_prekey_bundle(self, agent_id: str) -> dict[str, Any] | None:
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
stored = self._prekey_bundles.get(agent_id)
if not stored:
return None
return dict(stored)
def get_prekey_bundle_by_lookup(self, lookup_token: str) -> tuple[dict[str, Any] | None, str]:
"""Resolve a lookup alias to a prekey bundle. Returns (bundle, agent_id)."""
with self._lock:
self._refresh_from_shared_relay()
self._prune_stale_metadata()
resolved_id = self._resolve_prekey_lookup_alias(lookup_token)
if not resolved_id:
return None, ""
stored = self._prekey_bundles.get(resolved_id)
if not stored:
return None, ""
return dict(stored), resolved_id
def register_prekey_lookup_alias(
self,
alias: str,
agent_id: str,
*,
expires_at: int = 0,
max_uses: int = 0,
use_count: int = 0,
last_used_at: int = 0,
) -> None:
"""Register a lookup alias for an agent's prekey bundle."""
handle = str(alias or "").strip()
aid = str(agent_id or "").strip()
if handle and aid:
with self._lock:
self._refresh_from_shared_relay()
self._prekey_lookup_aliases[handle] = self._make_prekey_lookup_alias_record(
aid,
expires_at=expires_at,
max_uses=max_uses,
use_count=use_count,
last_used_at=last_used_at,
)
self._save()
def unregister_prekey_lookup_alias(self, alias: str) -> bool:
"""Remove an invite-scoped lookup alias from the local relay."""
handle = str(alias or "").strip()
if not handle:
return False
removed = False
with self._lock:
self._refresh_from_shared_relay()
if handle in self._prekey_lookup_aliases:
del self._prekey_lookup_aliases[handle]
removed = True
if removed:
self._save()
return removed
def consume_one_time_prekey(self, agent_id: str) -> dict[str, Any] | None:
"""Atomically claim the next published one-time prekey for a peer bundle."""
claimed: dict[str, Any] | None = None
with self._lock:
self._refresh_from_shared_relay()
stored = self._prekey_bundles.get(agent_id)
if not stored:
return None
bundle = dict(stored.get("bundle") or {})
otks = list(bundle.get("one_time_prekeys") or [])
if not otks:
return dict(stored)
claimed = dict(otks.pop(0) or {})
bundle["one_time_prekeys"] = otks
bundle["one_time_prekey_count"] = len(otks)
stored = dict(stored)
stored["bundle"] = bundle
stored["updated_at"] = int(time.time())
self._prekey_bundles[agent_id] = stored
self._save()
result = dict(stored)
result["claimed_one_time_prekey"] = claimed
return result
def _prune_witnesses(self, target_id: str) -> None:
cutoff = time.time() - self._witness_ttl_seconds()
self._witnesses[target_id] = [
w for w in self._witnesses.get(target_id, []) if float(w.get("timestamp", 0)) >= cutoff
]
if not self._witnesses[target_id]:
del self._witnesses[target_id]
def record_witness(
self,
witness_id: str,
target_id: str,
dh_pub_key: str,
timestamp: int,
) -> tuple[bool, str]:
if not witness_id or not target_id or not dh_pub_key:
return False, "Missing witness_id, target_id, or dh_pub_key"
if witness_id == target_id:
return False, "Cannot witness yourself"
with self._lock:
self._refresh_from_shared_relay()
self._prune_witnesses(target_id)
entries = self._witnesses.get(target_id, [])
for entry in entries:
if entry.get("witness_id") == witness_id and entry.get("dh_pub_key") == dh_pub_key:
return False, "Duplicate witness"
entries.append(
{
"witness_id": witness_id,
"dh_pub_key": dh_pub_key,
"timestamp": int(timestamp),
}
)
self._witnesses[target_id] = entries[-50:]
self._save()
return True, "ok"
def get_witnesses(self, target_id: str, dh_pub_key: str | None = None, limit: int = 5) -> list[dict]:
with self._lock:
self._refresh_from_shared_relay()
self._prune_witnesses(target_id)
entries = list(self._witnesses.get(target_id, []))
if dh_pub_key:
entries = [e for e in entries if e.get("dh_pub_key") == dh_pub_key]
entries = sorted(entries, key=lambda e: e.get("timestamp", 0), reverse=True)
return entries[: max(1, limit)]
def _epoch_bucket(self, ts: float | None = None) -> int:
now = ts if ts is not None else time.time()
return int(now // EPOCH_SECONDS)
def _mailbox_limit_for_class(self, delivery_class: str) -> int:
if delivery_class == "request":
return self._request_mailbox_limit()
if delivery_class == "shared":
return self._shared_mailbox_limit()
return self._self_mailbox_limit()
def _cleanup_expired(self) -> bool:
now = time.time()
changed = False
for mailbox_id in list(self._mailboxes):
fresh = [m for m in self._mailboxes[mailbox_id] if now - m.timestamp < TTL_SECONDS]
if len(fresh) != len(self._mailboxes[mailbox_id]):
changed = True
self._mailboxes[mailbox_id] = fresh
if not self._mailboxes[mailbox_id]:
del self._mailboxes[mailbox_id]
changed = True
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
return changed
def _total_nonce_count(self) -> int:
return sum(len(c) for c in self._nonce_caches.values())
def _trim_global_nonce_budget(self, *, preferred_agent_id: str = "") -> int:
trimmed = 0
max_entries = self._nonce_cache_max_entries()
preferred_agent_id = str(preferred_agent_id or "").strip()
while self._total_nonce_count() >= max_entries:
oldest_choice: tuple[str, str, float] | None = None
for aid, cache in self._nonce_caches.items():
if not cache:
continue
if preferred_agent_id and aid == preferred_agent_id and len(self._nonce_caches) > 1:
continue
nonce_value, expiry = next(iter(cache.items()))
if oldest_choice is None or float(expiry) < oldest_choice[2]:
oldest_choice = (aid, nonce_value, float(expiry))
if oldest_choice is None and preferred_agent_id:
for aid, cache in self._nonce_caches.items():
if not cache:
continue
nonce_value, expiry = next(iter(cache.items()))
if oldest_choice is None or float(expiry) < oldest_choice[2]:
oldest_choice = (aid, nonce_value, float(expiry))
if oldest_choice is None:
break
aid, nonce_value, _expiry = oldest_choice
cache = self._nonce_caches.get(aid)
if not cache:
continue
cache.pop(nonce_value, None)
if not cache:
self._nonce_caches.pop(aid, None)
trimmed += 1
if trimmed:
metrics_inc("dm_nonce_cache_trimmed")
return trimmed
def consume_nonce(self, agent_id: str, nonce: str, timestamp: int) -> tuple[bool, str]:
nonce = str(nonce or "").strip()
if not nonce:
return False, "Missing nonce"
agent_id = str(agent_id or "").strip()
now = time.time()
with self._lock:
self._refresh_from_shared_relay()
# Expire stale entries across all agents
for aid in list(self._nonce_caches):
cache = self._nonce_caches[aid]
self._nonce_caches[aid] = OrderedDict(
(k, exp) for k, exp in cache.items() if float(exp) > now
)
if not self._nonce_caches[aid]:
del self._nonce_caches[aid]
agent_cache = self._nonce_caches.get(agent_id)
# Replay check
if agent_cache and nonce in agent_cache:
metrics_inc("dm_nonce_replay")
return False, "nonce replay detected"
# Per-agent capacity check
per_agent_max = self._nonce_per_agent_max()
if agent_cache and len(agent_cache) >= per_agent_max:
metrics_inc("dm_nonce_cache_full")
return False, "nonce cache at capacity"
# Global capacity is a soft memory bound. Trim the oldest nonce
# entries first so one busy agent cannot turn the global budget
# into a cross-agent availability choke point.
if self._total_nonce_count() >= self._nonce_cache_max_entries():
self._trim_global_nonce_budget(preferred_agent_id=agent_id)
expiry = max(now + self._nonce_ttl_seconds(), float(timestamp) + self._nonce_ttl_seconds())
if agent_cache is None:
agent_cache = OrderedDict()
self._nonce_caches[agent_id] = agent_cache
agent_cache[nonce] = expiry
agent_cache.move_to_end(nonce)
self._save()
return True, "ok"
def deposit(
self,
*,
sender_id: str,
raw_sender_id: str = "",
recipient_id: str = "",
ciphertext: str,
msg_id: str = "",
delivery_class: str,
recipient_token: str | None = None,
sender_seal: str = "",
relay_salt: str = "",
sender_token_hash: str = "",
payload_format: str = "dm1",
session_welcome: str = "",
) -> dict[str, Any]:
with self._lock:
self._refresh_from_shared_relay()
authority_sender = str(raw_sender_id or sender_id or "").strip()
sender_block_ref = self._sender_block_ref(
authority_sender,
scope=self._sender_block_scope(
recipient_id=recipient_id,
recipient_token=str(recipient_token or ""),
delivery_class=delivery_class,
),
)
blocked_refs = self._sender_block_refs(
authority_sender,
recipient_id=recipient_id,
recipient_token=str(recipient_token or ""),
delivery_class=delivery_class,
)
if recipient_id and any(ref in self._blocks.get(recipient_id, set()) for ref in blocked_refs):
metrics_inc("dm_drop_blocked")
return {"ok": False, "detail": "Recipient is not accepting your messages"}
if len(ciphertext) > int(self._settings().MESH_DM_MAX_MSG_BYTES):
metrics_inc("dm_drop_oversize")
return {
"ok": False,
"detail": f"Message too large ({len(ciphertext)} > {int(self._settings().MESH_DM_MAX_MSG_BYTES)})",
}
self._cleanup_expired()
if delivery_class == "request":
if not sender_token_hash:
return {"ok": False, "detail": "sender_token required for request delivery"}
mailbox_key = self._mailbox_key("requests", recipient_id)
elif delivery_class == "shared":
if not recipient_token:
metrics_inc("dm_claim_invalid")
return {"ok": False, "detail": "recipient_token required for shared delivery"}
mailbox_key = self._hashed_mailbox_token(recipient_token)
else:
return {"ok": False, "detail": "Unsupported delivery_class"}
if len(self._mailboxes[mailbox_key]) >= self._mailbox_limit_for_class(delivery_class):
metrics_inc("dm_drop_full")
return {"ok": False, "detail": "Recipient mailbox full"}
# Anti-spam: per-(sender, recipient) cap on unacked messages.
# A sender who already has the configured number of messages
# parked in this mailbox can't deposit more until the recipient
# pulls (acks) at least one. The same cap is re-enforced on
# inbound replication in ``accept_replica`` so this rule isn't
# bypassable by patching out the local check on a hostile
# sender's relay — see config.py
# MESH_DM_PENDING_PER_SENDER_LIMIT for the threat model.
per_sender_limit = self._per_sender_pending_limit()
pending = self._per_sender_pending_count(
mailbox_key=mailbox_key,
sender_block_ref=sender_block_ref,
)
if pending >= per_sender_limit:
metrics_inc("dm_drop_per_sender_cap")
return {
"ok": False,
"detail": (
f"Recipient already has {pending} unread message"
f"{'s' if pending != 1 else ''} from you. Wait for "
"them to read your messages before sending more."
),
}
if not msg_id:
msg_id = f"dm_{int(time.time() * 1000)}_{secrets.token_hex(6)}"
elif any(m.msg_id == msg_id for m in self._mailboxes[mailbox_key]):
return {"ok": True, "msg_id": msg_id}
relay_sender_id = (
f"sender_token:{sender_token_hash}"
if sender_token_hash
else sender_id
)
self._mailboxes[mailbox_key].append(
DMMessage(
sender_id=relay_sender_id,
ciphertext=ciphertext,
timestamp=time.time(),
msg_id=msg_id,
delivery_class=delivery_class,
sender_seal=sender_seal,
sender_block_ref=sender_block_ref,
payload_format=str(payload_format or "dm1"),
session_welcome=str(session_welcome or ""),
)
)
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
return {"ok": True, "msg_id": msg_id}
def accept_replica(
self,
*,
envelope: dict[str, Any],
originating_peer_url: str = "",
) -> dict[str, Any]:
"""Receive a DM envelope replicated from a peer relay.
Cross-node mailbox replication entry point. When a sender's local
relay accepts a ``deposit`` and pushes the envelope to
``MESH_RELAY_PEERS`` (so the recipient can log into any peer
node and find their messages), each receiving peer calls
``accept_replica`` to ingest it.
The per-(sender, recipient) cap is re-enforced HERE. That's what
makes the rule a NETWORK rule rather than a client-side honor
system: a hostile sender who patches out the local ``deposit``
check still can't get a 3rd unacked message to spread, because
every honest peer enforces the same cap on inbound replicas.
Result: hostile relays can hold extras locally, but those extras
never reach any node a legitimate recipient is polling from.
Returns the same shape as ``deposit`` so the calling endpoint can
forward the result back to the originating peer.
"""
if not isinstance(envelope, dict):
return {"ok": False, "detail": "envelope must be an object"}
msg_id = str(envelope.get("msg_id", "") or "").strip()
mailbox_key = str(envelope.get("mailbox_key", "") or "").strip()
sender_block_ref = str(envelope.get("sender_block_ref", "") or "").strip()
ciphertext = str(envelope.get("ciphertext", "") or "")
if not msg_id or not mailbox_key or not sender_block_ref or not ciphertext:
return {"ok": False, "detail": "envelope missing required fields"}
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
# Idempotent — if we already hold this exact msg_id, the
# replication round-tripped or a peer pushed the same
# envelope through multiple paths. Accept silently.
if any(m.msg_id == msg_id for m in self._mailboxes.get(mailbox_key, [])):
metrics_inc("dm_replica_duplicate")
return {"ok": True, "msg_id": msg_id, "duplicate": True}
# Same per-class cap as the deposit path — defense in depth
# against a peer that wraps a "deposit" as a "replica" to
# bypass the class limit.
delivery_class = str(envelope.get("delivery_class", "") or "")
if delivery_class in ("request", "shared", "self"):
class_limit = self._mailbox_limit_for_class(delivery_class)
else:
class_limit = self._shared_mailbox_limit()
if len(self._mailboxes.get(mailbox_key, [])) >= class_limit:
metrics_inc("dm_replica_drop_full")
return {"ok": False, "detail": "Recipient mailbox full"}
# THE network rule: per-(sender, recipient) anti-spam cap.
per_sender_limit = self._per_sender_pending_limit()
pending = self._per_sender_pending_count(
mailbox_key=mailbox_key,
sender_block_ref=sender_block_ref,
)
if pending >= per_sender_limit:
metrics_inc("dm_replica_drop_per_sender_cap")
# Returning a structured rejection — the sender's relay
# learns its envelope was rejected by an honest peer and
# can stop trying to push it.
return {
"ok": False,
"detail": (
"Per-sender cap reached on this relay; refusing replica"
),
"cap_violation": True,
"pending": pending,
"limit": per_sender_limit,
}
# Accept the replica into the local mailbox.
self._mailboxes[mailbox_key].append(
DMMessage(
sender_id=str(envelope.get("sender_id", "") or ""),
ciphertext=ciphertext,
timestamp=float(envelope.get("timestamp", time.time()) or time.time()),
msg_id=msg_id,
delivery_class=str(envelope.get("delivery_class", "shared") or "shared"),
sender_seal=str(envelope.get("sender_seal", "") or ""),
relay_salt=str(envelope.get("relay_salt", "") or ""),
sender_block_ref=sender_block_ref,
payload_format=str(envelope.get("payload_format", "dm1") or "dm1"),
session_welcome=str(envelope.get("session_welcome", "") or ""),
)
)
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
metrics_inc("dm_replica_accepted")
return {"ok": True, "msg_id": msg_id}
def envelope_for_replication(
self,
*,
mailbox_key: str,
msg_id: str,
) -> dict[str, Any] | None:
"""Return the wire-form envelope for a stored message, suitable
for POSTing to a peer relay's replicate-envelope endpoint.
Returns ``None`` if the message isn't in the mailbox (already
acked, expired, never existed). The caller holds the
responsibility for transport security (Tor SOCKS for .onion
peers, per-peer HMAC) and for not leaking the envelope to
clearnet peers when private transport is required.
"""
with self._lock:
for m in self._mailboxes.get(mailbox_key, []):
if m.msg_id == msg_id:
return {
"msg_id": m.msg_id,
"mailbox_key": mailbox_key,
"sender_id": m.sender_id,
"sender_block_ref": m.sender_block_ref,
"sender_seal": m.sender_seal,
"ciphertext": m.ciphertext,
"timestamp": m.timestamp,
"delivery_class": m.delivery_class,
"relay_salt": m.relay_salt,
"payload_format": m.payload_format,
"session_welcome": m.session_welcome,
}
return None
def is_blocked(self, recipient_id: str, sender_id: str) -> bool:
with self._lock:
self._refresh_from_shared_relay()
if not recipient_id:
return False
blocked_refs = self._sender_block_refs(
sender_id,
recipient_id=recipient_id,
delivery_class="request",
)
return any(ref in self._blocks.get(recipient_id, set()) for ref in blocked_refs)
def _collect_from_keys(
self, keys: list[str], *, destructive: bool, limit: int = 0,
) -> tuple[list[dict[str, Any]], bool]:
messages: list[DMMessage] = []
seen: set[str] = set()
popped: dict[str, list[DMMessage]] = {}
for key in keys:
if destructive:
raw = self._mailboxes.pop(key, [])
popped[key] = raw
else:
raw = list(self._mailboxes.get(key, []))
for message in raw:
if message.msg_id in seen:
continue
seen.add(message.msg_id)
messages.append(message)
sorted_messages = sorted(messages, key=lambda item: item.timestamp)
has_more = False
if limit > 0 and len(sorted_messages) > limit:
has_more = True
kept = sorted_messages[:limit]
if destructive:
kept_ids = {m.msg_id for m in kept}
for key, original in popped.items():
remaining = [m for m in original if m.msg_id not in kept_ids]
if remaining:
self._mailboxes.setdefault(key, []).extend(remaining)
sorted_messages = kept
if destructive:
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
result = [
{
"sender_id": message.sender_id,
"ciphertext": message.ciphertext,
"timestamp": message.timestamp,
"msg_id": message.msg_id,
"delivery_class": message.delivery_class,
"sender_seal": message.sender_seal,
"format": message.payload_format,
"session_welcome": message.session_welcome,
}
for message in sorted_messages
]
return result, has_more
def collect_claims(
self, agent_id: str, claims: list[dict[str, Any]], *, limit: int = 0,
) -> tuple[list[dict[str, Any]], bool]:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
return self._collect_from_keys(list(dict.fromkeys(keys)), destructive=True, limit=limit)
def count_claims(self, agent_id: str, claims: list[dict[str, Any]]) -> int:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
return len(messages)
def claim_message_ids(self, agent_id: str, claims: list[dict[str, Any]]) -> set[str]:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
keys: list[str] = []
for claim in claims[:32]:
keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
return {
str(message.get("msg_id", "") or "")
for message in messages
if str(message.get("msg_id", "") or "")
}
def collect_legacy(
self, agent_id: str | None = None, agent_token: str | None = None, *, limit: int = 0,
) -> tuple[list[dict[str, Any]], bool]:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
if not agent_token:
return [], False
keys = [self._pepper_token(agent_token), agent_token]
return self._collect_from_keys(list(dict.fromkeys(keys)), destructive=True, limit=limit)
def count_legacy(self, agent_id: str | None = None, agent_token: str | None = None) -> int:
with self._lock:
self._refresh_from_shared_relay()
self._cleanup_expired()
if not agent_token:
return 0
keys = [self._pepper_token(agent_token), agent_token]
messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
return len(messages)
def block(self, agent_id: str, blocked_id: str) -> None:
with self._lock:
self._refresh_from_shared_relay()
blocked_ref = self._canonical_blocked_id(
blocked_id,
scope=self._sender_block_scope(recipient_id=agent_id, delivery_class="request"),
)
if not blocked_ref:
return
self._blocks[agent_id].add(blocked_ref)
blocked_refs = {blocked_ref}
blocked_label = str(blocked_id or "").strip()
if blocked_label and not blocked_label.startswith("ref:"):
blocked_refs.add(self._legacy_sender_block_ref(blocked_label))
purge_keys = self._legacy_token_candidates(agent_id)
bound_request = self._bound_mailbox_key(agent_id, "requests")
bound_self = self._bound_mailbox_key(agent_id, "self")
if bound_request:
purge_keys.append(bound_request)
if bound_self:
purge_keys.append(bound_self)
purge_keys.extend(
[
self._mailbox_key("self", agent_id),
self._mailbox_key("requests", agent_id),
self._mailbox_key("self", agent_id, self._epoch_bucket() - 1),
self._mailbox_key("requests", agent_id, self._epoch_bucket() - 1),
]
)
for key in set(purge_keys):
if key in self._mailboxes:
self._mailboxes[key] = [
m for m in self._mailboxes[key] if self._message_block_ref(m) not in blocked_refs
]
self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
self._save()
def unblock(self, agent_id: str, blocked_id: str) -> None:
with self._lock:
self._refresh_from_shared_relay()
blocked_ref = self._canonical_blocked_id(
blocked_id,
scope=self._sender_block_scope(recipient_id=agent_id, delivery_class="request"),
)
if not blocked_ref:
return
self._blocks[agent_id].discard(blocked_ref)
blocked_label = str(blocked_id or "").strip()
if blocked_label and not blocked_label.startswith("ref:"):
self._blocks[agent_id].discard(self._legacy_sender_block_ref(blocked_label))
self._save()
dm_relay = DMRelay()