"""Metadata-minimized DM relay for request and shared mailboxes.

This relay never decrypts application payloads. In secure mode it keeps
pending ciphertext in memory only and persists just the minimum metadata
needed for continuity: accepted DH bundles, block lists, witness data, and
nonce replay windows.
"""

from __future__ import annotations

import atexit
import hashlib
import json
import logging
import os
import secrets
import threading
import time
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from services.config import get_settings
from services.mesh.mesh_metrics import increment as metrics_inc
from services.mesh.mesh_wormhole_prekey import (
    _validate_bundle_record,
    transparency_fingerprint_for_bundle_record,
)
from services.mesh.mesh_secure_storage import read_secure_json, write_secure_json

TTL_SECONDS = 3600
# Request/self mailbox keys are bucketed into 6-hour epochs (see _mailbox_key).
EPOCH_SECONDS = 6 * 60 * 60
BACKEND_DIR = Path(__file__).resolve().parents[2]
DEFAULT_DATA_DIR = BACKEND_DIR / "data"
DEFAULT_RELAY_FILE = DEFAULT_DATA_DIR / "dm_relay.json"
# Mutable module-level aliases; tests monkeypatch RELAY_FILE to isolate spools
# (see DMRelay._relay_file / _auto_reload_enabled, which compare against the
# DEFAULT_* values to detect an override).
DATA_DIR = DEFAULT_DATA_DIR
RELAY_FILE = DEFAULT_RELAY_FILE

logger = logging.getLogger(__name__)


def _stable_json(value: Any) -> str:
    # Canonical JSON (sorted keys, no whitespace) so hashes computed over the
    # serialized payload are deterministic across runs.
    return json.dumps(value, sort_keys=True, separators=(",", ":"))


def _get_token_pepper() -> str:
    """Read token pepper lazily so auto-generated values from startup audit take effect."""
    pepper = os.environ.get("MESH_DM_TOKEN_PEPPER", "").strip()
    if not pepper:
        try:
            # Local imports: the env_check helper may generate and export the
            # pepper on demand; failure falls back to re-reading the env var.
            from services.config import get_settings
            from services.env_check import _ensure_dm_token_pepper

            pepper = _ensure_dm_token_pepper(get_settings())
        except Exception:
            pepper = os.environ.get("MESH_DM_TOKEN_PEPPER", "").strip()
    if not pepper:
        # Hashing mailbox/block identifiers without a pepper would be unsafe,
        # so refuse to continue.
        raise RuntimeError("MESH_DM_TOKEN_PEPPER is unavailable at runtime")
    return pepper


@dataclass
class DMMessage:
    # A single relayed (still-encrypted) message held in a mailbox spool.
    sender_id: str
    ciphertext: str
    timestamp: float
    msg_id: str
    delivery_class: str
    sender_seal: str = ""
    relay_salt: str = ""
    # Opaque peppered hash used for block-list matching (see _message_block_ref).
    sender_block_ref: str = ""
    payload_format: str = "dm1"
    session_welcome: str = ""


class DMRelay:
    """Relay for encrypted request/shared mailboxes."""

    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._mailboxes: dict[str, list[DMMessage]] = defaultdict(list)
        self._dh_keys: dict[str, dict[str, Any]] = {}
        self._prekey_bundles: dict[str, dict[str, Any]] = {}
        self._mailbox_bindings: dict[str, dict[str, Any]] = defaultdict(dict)
        self._witnesses: dict[str, list[dict[str, Any]]] = defaultdict(list)
        self._blocks: dict[str, set[str]] = defaultdict(set)
        self._nonce_caches: dict[str, OrderedDict[str, float]] = {}
        """Per-agent nonce replay caches — keyed by agent_id, values are OrderedDicts of nonce→expiry."""
        self._prekey_lookup_aliases: dict[str, dict[str, Any]] = {}
        """Invite-scoped lookup handle → agent_id for prekey bundle fetch without stable identity."""
        self._stats: dict[str, int] = {"messages_in_memory": 0}
        # Persistence bookkeeping: _save() marks dirty and schedules a
        # coalesced _flush() via a daemon Timer; atexit flushes on shutdown.
        self._dirty = False
        self._save_timer: threading.Timer | None = None
        self._last_persist_error = ""
        self._SAVE_INTERVAL = 5.0
        atexit.register(self._flush)
        self._load()

    def _settings(self):
        # Re-read settings on every access so runtime reconfiguration applies.
        return get_settings()

    def _persist_spool_enabled(self) -> bool:
        # When False, pending ciphertext stays memory-only (secure mode).
        return bool(self._settings().MESH_DM_PERSIST_SPOOL)

    def _relay_file(self) -> Path:
        # Unit tests frequently monkeypatch the module-level relay file so each
        # relay instance stays isolated from the shared runtime spool path.
        # Module-level monkeypatched override wins over settings.
        module_override = Path(RELAY_FILE)
        if module_override != DEFAULT_RELAY_FILE:
            return module_override.expanduser().resolve()
        override = str(getattr(self._settings(), "MESH_DM_RELAY_FILE_PATH", "") or "").strip()
        if override:
            override_path = Path(override).expanduser()
            if not override_path.is_absolute():
                # Relative overrides are resolved against the backend root.
                override_path = BACKEND_DIR / override_path
            return override_path.resolve()
        return RELAY_FILE

    def _relay_data_dir(self) -> Path:
        # Directory that must exist before _flush() writes the relay file.
        return self._relay_file().parent

    def _auto_reload_enabled(self) -> bool:
        # Auto-reload only applies to the shared default relay file; a
        # monkeypatched RELAY_FILE (tests) disables it.
        if Path(RELAY_FILE) != DEFAULT_RELAY_FILE:
            return False
        return bool(getattr(self._settings(), "MESH_DM_RELAY_AUTO_RELOAD", False))

    def _refresh_from_shared_relay(self) -> None:
        # Pull the latest on-disk snapshot before reads/writes when enabled
        # (multi-process deployments sharing one relay file).
        if self._auto_reload_enabled():
            self._reload_snapshot_from_shared_relay()

    def _reload_snapshot_from_shared_relay(self) -> None:
        """Rebuild all in-memory state from the shared relay file.

        Fresh containers are populated first and swapped in atomically at the
        end so readers never observe a half-loaded snapshot. Memory-only data
        (spool when persistence is off, bindings when metadata persistence is
        off) is carried over from the current in-memory state instead of being
        dropped.
        """
        relay_file = self._relay_file()
        fresh_mailboxes: defaultdict[str, list[DMMessage]] = defaultdict(list)
        fresh_dh_keys: dict[str, dict[str, Any]] = {}
        fresh_prekey_bundles: dict[str, dict[str, Any]] = {}
        fresh_mailbox_bindings: defaultdict[str, dict[str, Any]] = defaultdict(dict)
        fresh_witnesses: defaultdict[str, list[dict[str, Any]]] = defaultdict(list)
        fresh_blocks: defaultdict[str, set[str]] = defaultdict(set)
        fresh_nonce_caches: dict[str, OrderedDict[str, float]] = {}
        fresh_prekey_lookup_aliases: dict[str, dict[str, Any]] = {}
        fresh_stats: dict[str, int] = {"messages_in_memory": 0}
        # Snapshot current memory-only state so it survives the reload when the
        # corresponding persistence toggle is off.
        current_mailboxes = defaultdict(list, {k: list(v) for k, v in self._mailboxes.items()})
        current_bindings = defaultdict(
            dict,
            {
                str(agent_id): {
                    str(kind): dict(entry)
                    for kind, entry in bindings.items()
                    if isinstance(entry, dict)
                }
                for agent_id, bindings in self._mailbox_bindings.items()
                if isinstance(bindings, dict)
            },
        )
        if not relay_file.exists():
            # No shared file: reset everything except deliberately memory-only data.
            if not self._persist_spool_enabled():
                fresh_mailboxes = current_mailboxes
            if not self._metadata_persist_enabled():
                fresh_mailbox_bindings = current_bindings
            self._mailboxes = fresh_mailboxes
            self._dh_keys = fresh_dh_keys
            self._prekey_bundles = fresh_prekey_bundles
            self._mailbox_bindings = fresh_mailbox_bindings
            self._witnesses = fresh_witnesses
            self._blocks = fresh_blocks
            self._nonce_caches = fresh_nonce_caches
            self._prekey_lookup_aliases = fresh_prekey_lookup_aliases
            self._stats = fresh_stats
            return
        try:
            data = read_secure_json(relay_file, lambda: {})
        except Exception:
            # Unreadable snapshot: keep current in-memory state untouched.
            return
        if self._persist_spool_enabled():
            mailboxes = data.get("mailboxes", {})
            if isinstance(mailboxes, dict):
                for key, items in mailboxes.items():
                    if not isinstance(items, list):
                        continue
                    restored: list[DMMessage] = []
                    for item in items:
                        try:
                            restored.append(
                                DMMessage(
                                    sender_id=str(item.get("sender_id", "")),
                                    ciphertext=str(item.get("ciphertext", "")),
                                    timestamp=float(item.get("timestamp", 0)),
                                    msg_id=str(item.get("msg_id", "")),
                                    delivery_class=str(item.get("delivery_class", "shared")),
                                    sender_seal=str(item.get("sender_seal", "")),
                                    relay_salt=str(item.get("relay_salt", "") or ""),
                                    sender_block_ref=str(item.get("sender_block_ref", "") or ""),
                                    payload_format=str(item.get("payload_format", item.get("format", "dm1")) or "dm1"),
                                    session_welcome=str(item.get("session_welcome", "") or ""),
                                )
                            )
                        except Exception:
                            # Skip individual malformed records, keep the rest.
                            continue
                    # Backfill block refs for records persisted before the field existed.
                    for message in restored:
                        if not message.sender_block_ref:
                            message.sender_block_ref = self._message_block_ref(message)
                    if restored:
                        fresh_mailboxes[str(key)] = restored
        else:
            if not self._persist_spool_enabled():
                fresh_mailboxes = current_mailboxes
        dh_keys = data.get("dh_keys", {})
        if isinstance(dh_keys, dict):
            fresh_dh_keys = {str(k): dict(v) for k, v in dh_keys.items() if isinstance(v, dict)}
        prekey_bundles = data.get("prekey_bundles", {})
        if isinstance(prekey_bundles, dict):
            fresh_prekey_bundles = {
                str(k): dict(v) for k, v in prekey_bundles.items() if isinstance(v, dict)
            }
        prekey_lookup_aliases = data.get("prekey_lookup_aliases", {})
        if isinstance(prekey_lookup_aliases, dict):
            for key, value in prekey_lookup_aliases.items():
                handle = str(key or "").strip()
                record = self._coerce_prekey_lookup_alias_record(value)
                if handle and record:
                    fresh_prekey_lookup_aliases[handle] = record
        now = time.time()
        mailbox_bindings = data.get("mailbox_bindings", {})
        if isinstance(mailbox_bindings, dict) and self._metadata_persist_enabled():
            for agent_id, bindings in mailbox_bindings.items():
                if not isinstance(bindings, dict):
                    continue
                restored_agent: dict[str, dict[str, Any]] = {}
                for kind, entry in bindings.items():
                    token_hash = ""
                    last_used = now
                    if isinstance(entry, dict):
                        token_hash = str(entry.get("token_hash", "") or "").strip()
                        last_used = float(entry.get("last_used", now) or now)
                    else:
                        # Legacy format stored the bare token hash string.
                        token_hash = str(entry or "").strip()
                    if token_hash:
                        normalized = self._coerce_mailbox_binding_entry(
                            {
                                "token_hash": token_hash,
                                "bound_at": float(entry.get("bound_at", last_used) or last_used)
                                if isinstance(entry, dict)
                                else last_used,
                                "last_used": last_used,
                                "expires_at": float(entry.get("expires_at", 0) or 0)
                                if isinstance(entry, dict)
                                else 0,
                            },
                            now=now,
                        )
                        if normalized:
                            restored_agent[str(kind)] = normalized
                if restored_agent:
                    fresh_mailbox_bindings[str(agent_id)] = restored_agent
        elif not self._metadata_persist_enabled():
            fresh_mailbox_bindings = current_bindings
        witnesses = data.get("witnesses", {})
        if isinstance(witnesses, dict):
            fresh_witnesses = defaultdict(
                list,
                {str(k): list(v) for k, v in witnesses.items() if isinstance(v, list)},
            )
        blocks = data.get("blocks", {})
        if isinstance(blocks, dict):
            for key, values in blocks.items():
                if isinstance(values, list):
                    # Normalize legacy plain IDs into peppered "ref:" form.
                    fresh_blocks[str(key)] = {
                        self._canonical_blocked_id(str(v)) for v in values if str(v or "").strip()
                    }
        nonce_caches = data.get("nonce_caches", {})
        if isinstance(nonce_caches, dict) and nonce_caches:
            for aid, entries in nonce_caches.items():
                if not isinstance(entries, dict):
                    continue
                # Drop expired nonces and keep insertion order sorted by expiry
                # so eviction (oldest-first) stays cheap.
                restored = sorted(
                    ((str(k), float(v)) for k, v in entries.items() if float(v) > now),
                    key=lambda item: item[1],
                )
                if restored:
                    fresh_nonce_caches[str(aid)] = OrderedDict(restored)
        else:
            # Backward compatibility: flat "agent:nonce" composite keys.
            nonce_cache = data.get("nonce_cache", {})
            if isinstance(nonce_cache, dict):
                for composite_key, expiry in nonce_cache.items():
                    if float(expiry) <= now:
                        continue
                    parts = str(composite_key).split(":", 1)
                    if len(parts) == 2:
                        aid, nonce_val = parts
                        if aid not in fresh_nonce_caches:
                            fresh_nonce_caches[aid] = OrderedDict()
                        fresh_nonce_caches[aid][nonce_val] = float(expiry)
        stats = data.get("stats", {})
        if isinstance(stats, dict):
            fresh_stats = {str(k): int(v) for k, v in stats.items() if isinstance(v, (int, float))}
        # Atomic swap of the whole snapshot.
        self._mailboxes = fresh_mailboxes
        self._dh_keys = fresh_dh_keys
        self._prekey_bundles = fresh_prekey_bundles
        self._mailbox_bindings = fresh_mailbox_bindings
        self._witnesses = fresh_witnesses
        self._blocks = fresh_blocks
        self._nonce_caches = fresh_nonce_caches
        self._prekey_lookup_aliases = fresh_prekey_lookup_aliases
        self._stats = fresh_stats
        self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
        if self._prune_stale_metadata():
            self._dirty = True

    # --- settings-backed limits / TTLs (all clamped to sane minimums) ---

    def _request_mailbox_limit(self) -> int:
        return max(1, int(self._settings().MESH_DM_REQUEST_MAILBOX_LIMIT))

    def _shared_mailbox_limit(self) -> int:
        return max(1, int(self._settings().MESH_DM_SHARED_MAILBOX_LIMIT))

    def _self_mailbox_limit(self) -> int:
        return max(1, int(self._settings().MESH_DM_SELF_MAILBOX_LIMIT))

    def _nonce_ttl_seconds(self) -> int:
        return max(30, int(self._settings().MESH_DM_NONCE_TTL_S))

    def _nonce_cache_max_entries(self) -> int:
        return max(1, int(getattr(self._settings(), "MESH_DM_NONCE_CACHE_MAX", 4096)))

    def _nonce_per_agent_max(self) -> int:
        return max(1, int(getattr(self._settings(), "MESH_DM_NONCE_PER_AGENT_MAX", 256)))

    def _dm_key_ttl_seconds(self) -> int:
        # Days → seconds; "or 30" guards against a falsy configured value.
        return max(1, int(getattr(self._settings(), "MESH_DM_KEY_TTL_DAYS", 30) or 30)) * 86400

    def _prekey_lookup_alias_ttl_seconds(self) -> int:
        return max(
            1,
            int(getattr(self._settings(), "MESH_DM_PREKEY_LOOKUP_ALIAS_TTL_DAYS", 14) or 14),
        ) * 86400

    def _witness_ttl_seconds(self) -> int:
        return max(1, int(getattr(self._settings(), "MESH_DM_WITNESS_TTL_DAYS", 14) or 14)) * 86400

    def _mailbox_binding_ttl_seconds(self) -> int:
        return max(1, int(getattr(self._settings(), "MESH_DM_BINDING_TTL_DAYS", 3) or 3)) * 86400

    def _mailbox_binding_idle_ttl_seconds(self) -> int:
        # Idle expiry: never longer than the absolute TTL, capped at 12h.
        return min(self._mailbox_binding_ttl_seconds(), 12 * 60 * 60)

    def _mailbox_binding_refresh_after_seconds(self) -> int:
        # How old bound_at may get before re-binding resets it (15min..12h).
        return max(15 * 60, min(self._mailbox_binding_ttl_seconds(), 12 * 60 * 60))

    def _mailbox_binding_expires_at(self, entry: dict[str, Any]) -> float:
        """Compute the effective expiry: min(absolute TTL, idle TTL); 0.0 if never bound."""
        bound_at = float(entry.get("bound_at", 0) or 0)
        last_used = float(entry.get("last_used", bound_at) or bound_at)
        if bound_at <= 0:
            return 0.0
        return min(
            bound_at + self._mailbox_binding_ttl_seconds(),
            last_used + self._mailbox_binding_idle_ttl_seconds(),
        )

    def _coerce_mailbox_binding_entry(self, entry: Any, *, now: float | None = None) -> dict[str, Any]:
        """Normalize a binding entry (dict or legacy bare hash) into the canonical
        {token_hash, bound_at, last_used, expires_at} shape; {} if no token hash."""
        current = time.time() if now is None else float(now)
        token_hash = ""
        bound_at = current
        last_used = current
        if isinstance(entry, dict):
            token_hash = str(entry.get("token_hash", "") or "").strip()
            bound_at = float(entry.get("bound_at", entry.get("last_used", current)) or current)
            last_used = float(entry.get("last_used", bound_at) or bound_at)
        else:
            token_hash = str(entry or "").strip()
        if not token_hash:
            return {}
        normalized = {
            "token_hash": token_hash,
            "bound_at": bound_at,
            "last_used": last_used,
        }
        # expires_at is always recomputed, never trusted from input.
        normalized["expires_at"] = self._mailbox_binding_expires_at(normalized)
        return normalized

    def _alias_updated_at_for_agent(self, agent_id: str) -> float:
        # Alias freshness is anchored to the agent's stored prekey bundle
        # timestamp; falls back to "now" when no bundle exists.
        stored = self._prekey_bundles.get(str(agent_id or "").strip(), {})
        if isinstance(stored, dict):
            return float(stored.get("updated_at", stored.get("timestamp", time.time())) or time.time())
        return float(time.time())

    def _make_prekey_lookup_alias_record(
        self,
        agent_id: str,
        *,
        updated_at: float | None = None,
        expires_at: int = 0,
        max_uses: int = 0,
        use_count: int = 0,
        last_used_at: float = 0,
    ) -> dict[str, Any]:
        """Build a canonical alias record; expires_at/max_uses of 0 mean 'unlimited'."""
        aid = str(agent_id or "").strip()
        if not aid:
            return {}
        if updated_at is None:
            updated_at = self._alias_updated_at_for_agent(aid)
        return {
            "agent_id": aid,
            "updated_at": float(updated_at or self._alias_updated_at_for_agent(aid)),
            "expires_at": max(0, int(expires_at or 0)),
            "max_uses": max(0, int(max_uses or 0)),
            "use_count": max(0, int(use_count or 0)),
            "last_used_at": float(last_used_at or 0),
        }

    def _coerce_prekey_lookup_alias_record(self, value: Any) -> dict[str, Any]:
        """Normalize an alias record (dict or legacy bare agent-id string)."""
        if isinstance(value, dict):
            aid = str(value.get("agent_id", "") or "").strip()
            if not aid:
                return {}
            updated_at = float(
                value.get("updated_at", value.get("last_used", self._alias_updated_at_for_agent(aid)))
                or self._alias_updated_at_for_agent(aid)
            )
            return self._make_prekey_lookup_alias_record(
                aid,
                updated_at=updated_at,
                expires_at=int(value.get("expires_at", 0) or 0),
                max_uses=int(value.get("max_uses", 0) or 0),
                use_count=int(value.get("use_count", value.get("uses", 0)) or 0),
                last_used_at=float(value.get("last_used_at", value.get("last_used", 0)) or 0),
            )
        aid = str(value or "").strip()
        if not aid:
            return {}
        return self._make_prekey_lookup_alias_record(aid)

    def _resolve_prekey_lookup_alias(self, lookup_token: str) -> str:
        """Resolve a lookup handle to an agent_id, enforcing expiry/use limits
        and incrementing the use counter. Returns "" when invalid/expired."""
        handle = str(lookup_token or "").strip()
        if not handle:
            return ""
        record = self._coerce_prekey_lookup_alias_record(self._prekey_lookup_aliases.get(handle, {}))
        if not record:
            return ""
        now = time.time()
        expires_at = int(record.get("expires_at", 0) or 0)
        max_uses = int(record.get("max_uses", 0) or 0)
        use_count = int(record.get("use_count", 0) or 0)
        if (expires_at > 0 and now > expires_at) or (max_uses > 0 and use_count >= max_uses):
            # Exhausted or expired: delete the alias and persist the removal.
            self._prekey_lookup_aliases.pop(handle, None)
            self._save()
            return ""
        updated = self._make_prekey_lookup_alias_record(
            str(record.get("agent_id", "") or "").strip(),
            updated_at=float(record.get("updated_at", self._alias_updated_at_for_agent(str(record.get("agent_id", "") or "").strip())) or now),
            expires_at=expires_at,
            max_uses=max_uses,
            use_count=use_count + 1,
            last_used_at=now,
        )
        self._prekey_lookup_aliases[handle] = updated
        self._save()
        # Best-effort usage bookkeeping in the identity module; never let it
        # break alias resolution.
        try:
            from services.mesh.mesh_wormhole_identity import record_prekey_lookup_handle_use

            record_prekey_lookup_handle_use(handle, now=int(now))
        except Exception:
            pass
        return str(updated.get("agent_id", "") or "").strip()

    def _pepper_token(self, token: str) -> str:
        """SHA-256 of the token mixed with the server-side pepper, so stored
        hashes cannot be correlated without the pepper."""
        material = token
        pepper = _get_token_pepper()
        if pepper:
            material = f"{pepper}|{token}"
        return hashlib.sha256(material.encode("utf-8")).hexdigest()

    def _legacy_sender_block_ref(self, sender_id: str) -> str:
        # Unscoped (pre-scoping) block reference format, kept for old data.
        sender = str(sender_id or "").strip()
        if not sender:
            return ""
        return "ref:" + self._pepper_token(f"block|{sender}")

    def _sender_block_scope(
        self,
        *,
        recipient_id: str = "",
        recipient_token: str = "",
        delivery_class: str = "",
    ) -> str:
        """Derive the block scope: recipient identity if known, else the hashed
        shared-mailbox token; "" when no scope applies."""
        recipient = str(recipient_id or "").strip()
        if recipient:
            return f"recipient|{recipient}"
        token = str(recipient_token or "").strip()
        if token and str(delivery_class or "").strip().lower() == "shared":
            return f"shared|{self._hashed_mailbox_token(token)}"
        return ""

    def _sender_block_ref(self, sender_id: str, *, scope: str = "") -> str:
        """Peppered, optionally scope-bound block reference ("ref:<sha256>")."""
        sender = str(sender_id or "").strip()
        if not sender:
            return ""
        material = f"block|{scope}|{sender}" if scope else f"block|{sender}"
        return "ref:" + self._pepper_token(material)

    def _sender_block_refs(
        self,
        sender_id: str,
        *,
        recipient_id: str = "",
        recipient_token: str = "",
        delivery_class: str = "",
    ) -> set[str]:
        """All refs a sender may match under: legacy unscoped + current scoped."""
        refs: set[str] = set()
        legacy = self._legacy_sender_block_ref(sender_id)
        if legacy:
            refs.add(legacy)
        scoped = self._sender_block_ref(
            sender_id,
            scope=self._sender_block_scope(
                recipient_id=recipient_id,
                recipient_token=recipient_token,
                delivery_class=delivery_class,
            ),
        )
        if scoped:
            refs.add(scoped)
        return refs

    def _canonical_blocked_id(self, blocked_id: str, *, scope: str = "") -> str:
        # Pass through already-hashed "ref:" values; hash plain IDs.
        blocked = str(blocked_id or "").strip()
        if not blocked:
            return ""
        if blocked.startswith("ref:"):
            return blocked
        return self._sender_block_ref(blocked, scope=scope)

    def _message_block_ref(self, message: DMMessage) -> str:
        """Block ref for a message: stored value if present, else derived from a
        plain sender_id (sealed/tokenized senders yield no ref)."""
        block_ref = str(getattr(message, "sender_block_ref", "") or "").strip()
        if block_ref:
            return block_ref
        sender_id = str(message.sender_id or "").strip()
        if not sender_id or sender_id.startswith("sealed:") or sender_id.startswith("sender_token:"):
            return ""
        return self._legacy_sender_block_ref(sender_id)

    def _mailbox_key(self, mailbox_type: str, mailbox_value: str, epoch: int | None = None) -> str:
        """Peppered mailbox key; self/requests mailboxes rotate per epoch bucket."""
        if mailbox_type in {"self", "requests"}:
            # _epoch_bucket is defined elsewhere in this class — presumably
            # time.time() // EPOCH_SECONDS; verify against the full file.
            bucket = self._epoch_bucket() if epoch is None else int(epoch)
            identifier = f"{mailbox_type}|{bucket}|{mailbox_value}"
        else:
            identifier = f"{mailbox_type}|{mailbox_value}"
        return self._pepper_token(identifier)

    def _hashed_mailbox_token(self, token: str) -> str:
        # Unpeppered hash: shared-mailbox tokens must hash identically on every
        # relay instance regardless of local pepper.
        return hashlib.sha256(str(token or "").encode("utf-8")).hexdigest()

    def _remember_mailbox_binding(self, agent_id: str, mailbox_type: str, token: str) -> str:
        """Bind agent→mailbox token hash, refreshing last_used; returns the hash."""
        if self._prune_stale_mailbox_bindings():
            self._save()
        now = time.time()
        agent_key = str(agent_id or "").strip()
        mailbox_key = str(mailbox_type or "").strip().lower()
        token_hash = self._hashed_mailbox_token(token)
        current = self._coerce_mailbox_binding_entry(
            self._mailbox_bindings.get(agent_key, {}).get(mailbox_key, {}),
            now=now,
        )
        refreshed = {
            "token_hash": token_hash,
            "bound_at": now,
            "last_used": now,
        }
        if current and str(current.get("token_hash", "") or "") == token_hash:
            # Same token re-bound: keep the original bound_at unless it is old
            # enough to warrant a refresh.
            refreshed["bound_at"] = float(current.get("bound_at", now) or now)
            if (now - refreshed["bound_at"]) >= self._mailbox_binding_refresh_after_seconds():
                refreshed["bound_at"] = now
        refreshed["expires_at"] = self._mailbox_binding_expires_at(refreshed)
        self._mailbox_bindings[agent_key][mailbox_key] = refreshed
        self._save()
        return token_hash

    def _bound_mailbox_key(self, agent_id: str, mailbox_type: str) -> str:
        """Return the currently bound token hash for the agent/type, or ""."""
        if self._prune_stale_mailbox_bindings():
            self._save()
        agent_key = str(agent_id or "").strip()
        mailbox_key = str(mailbox_type or "").strip().lower()
        entry = self._mailbox_bindings.get(agent_key, {}).get(
            mailbox_key,
            {},
        )
        normalized = self._coerce_mailbox_binding_entry(entry)
        if normalized and normalized != entry:
            # Persist the normalized shape so legacy entries migrate on read.
            self._mailbox_bindings[agent_key][mailbox_key] = normalized
            self._save()
        return str(normalized.get("token_hash", "") or "")

    def _mailbox_keys_for_claim(self, agent_id: str, claim: dict[str, Any]) -> list[str]:
        """Expand one claim dict ({"type", "token"}) into the mailbox keys the
        agent may read: bound keys plus current and previous epoch keys."""
        claim_type = str(claim.get("type", "")).strip().lower()
        if claim_type == "shared":
            token = str(claim.get("token", "")).strip()
            if not token:
                metrics_inc("dm_claim_invalid")
                return []
            return [self._hashed_mailbox_token(token)]
        if claim_type == "requests":
            token = str(claim.get("token", "")).strip()
            if token:
                previous_bound = self._bound_mailbox_key(agent_id, "requests")
                bound_key = self._remember_mailbox_binding(agent_id, "requests", token)
                epoch = self._epoch_bucket()
                return [
                    key
                    for key in [
                        previous_bound,
                        bound_key,
                        self._mailbox_key("requests", agent_id, epoch),
                        self._mailbox_key("requests", agent_id, epoch - 1),
                    ]
                    if key
                ]
            metrics_inc("dm_claim_invalid")
            return []
        if claim_type == "self":
            token = str(claim.get("token", "")).strip()
            if token:
                previous_bound = self._bound_mailbox_key(agent_id, "self")
                bound_key = self._remember_mailbox_binding(agent_id, "self", token)
                epoch = self._epoch_bucket()
                return [
                    key
                    for key in [
                        previous_bound,
                        bound_key,
                        self._mailbox_key("self", agent_id, epoch),
                        self._mailbox_key("self", agent_id, epoch - 1),
                    ]
                    if key
                ]
            metrics_inc("dm_claim_invalid")
            return []
        metrics_inc("dm_claim_invalid")
        return []

    def mailbox_key_for_delivery(
        self,
        *,
        recipient_id: str,
        delivery_class: str,
        recipient_token: str | None = None,
    ) -> str:
        """Mailbox key a sender should deliver into.

        Raises ValueError for an unsupported delivery_class or a shared
        delivery without a recipient_token.
        """
        with self._lock:
            self._refresh_from_shared_relay()
            delivery_class = str(delivery_class or "").strip().lower()
            if delivery_class == "request":
                # Prefer the recipient's explicitly bound mailbox; fall back to
                # the epoch-derived key.
                bound_key = self._bound_mailbox_key(recipient_id, "requests")
                if bound_key:
                    return bound_key
                return self._mailbox_key("requests", str(recipient_id or "").strip())
            if delivery_class == "shared":
                token = str(recipient_token or "").strip()
                if not token:
                    raise ValueError("recipient_token required for shared delivery")
                return self._hashed_mailbox_token(token)
            raise ValueError("Unsupported delivery_class")

    def claim_mailbox_keys(self, agent_id: str, claims: list[dict[str, Any]]) -> list[str]:
        """Resolve up to 32 claims into a de-duplicated, order-preserving key list."""
        with self._lock:
            self._refresh_from_shared_relay()
            if self._prune_stale_mailbox_bindings():
                self._save()
            keys: list[str] = []
            for claim in claims[:32]:
                keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
            return list(dict.fromkeys(keys))

    def _legacy_mailbox_token(self, agent_id: str, epoch: int) -> str:
        # Pre-pepper token derivation, retained for migration lookups.
        raw = f"sb_dm|{epoch}|{agent_id}".encode("utf-8")
        return hashlib.sha256(raw).hexdigest()

    def _legacy_token_candidates(self, agent_id: str) -> list[str]:
        """Legacy tokens (current + previous epoch), both peppered and raw."""
        epoch = self._epoch_bucket()
        raw = [self._legacy_mailbox_token(agent_id, epoch), self._legacy_mailbox_token(agent_id, epoch - 1)]
        peppered = [self._pepper_token(token) for token in raw]
        return list(dict.fromkeys(peppered + raw))

    def _save(self) -> None:
        """Mark dirty and schedule a coalesced disk write."""
        self._dirty = True
        relay_file = self._relay_file()
        # Flush immediately when coalescing would be unsafe or unhelpful:
        # shared-file auto-reload, first-ever write, or test runs that need
        # deterministic persistence errors.
        if self._auto_reload_enabled() or not relay_file.exists() or self._persist_failures_are_fatal():
            self._flush()
            return
        with self._lock:
            if self._save_timer is None or not self._save_timer.is_alive():
                self._save_timer = threading.Timer(self._SAVE_INTERVAL, self._flush)
                self._save_timer.daemon = True
                self._save_timer.start()

    def _persist_failures_are_fatal(self) -> bool:
        # Under pytest, persistence failures should raise instead of being logged.
        return bool(os.environ.get("PYTEST_CURRENT_TEST", "").strip())

    def _record_persist_failure(self, operation: str, exc: Exception) -> None:
        self._last_persist_error = f"{operation}:{type(exc).__name__}:{exc}"
        metrics_inc("dm_relay_persist_failure")
        logger.exception("dm relay %s failed for %s", operation, self._relay_file())

    def _prune_stale_metadata(self) -> bool:
        """Remove expired relay metadata that should not outlive its retention window."""
        now = time.time()
        key_ttl = self._dm_key_ttl_seconds()
        alias_ttl = self._prekey_lookup_alias_ttl_seconds()
        witness_ttl = self._witness_ttl_seconds()
        changed = False
        # DH keys older than the key TTL are dropped.
        stale_keys = [
            aid
            for aid, entry in self._dh_keys.items()
            if (now - float(entry.get("timestamp", 0) or 0)) > key_ttl
        ]
        for aid in stale_keys:
            del self._dh_keys[aid]
            changed = True
        # Prekey bundles share the key TTL; updated_at preferred over timestamp.
        stale_bundles = [
            aid
            for aid, entry in self._prekey_bundles.items()
            if (now - float(entry.get("updated_at", entry.get("timestamp", 0)) or 0)) > key_ttl
        ]
        for aid in stale_bundles:
            del self._prekey_bundles[aid]
            changed = True
        # Aliases are dropped when malformed, pointing at a missing bundle,
        # past their TTL/expiry, or use-exhausted; well-formed records are
        # migrated to canonical shape in place.
        stale_aliases: list[str] = []
        for alias, value in list(self._prekey_lookup_aliases.items()):
            record = self._coerce_prekey_lookup_alias_record(value)
            if not record:
                stale_aliases.append(alias)
                continue
            if self._prekey_lookup_aliases.get(alias) != record:
                self._prekey_lookup_aliases[alias] = record
                changed = True
            target = str(record.get("agent_id", "") or "").strip()
            updated_at = float(record.get("updated_at", self._alias_updated_at_for_agent(target)) or 0)
            expires_at = int(record.get("expires_at", 0) or 0)
            max_uses = int(record.get("max_uses", 0) or 0)
            use_count = int(record.get("use_count", 0) or 0)
            if (
                not target
                or target not in self._prekey_bundles
                or (now - updated_at) > alias_ttl
                or (expires_at > 0 and now > float(expires_at))
                or (max_uses > 0 and use_count >= max_uses)
            ):
                stale_aliases.append(alias)
        for alias in stale_aliases:
            del self._prekey_lookup_aliases[alias]
            changed = True
        # Witness entries age out individually; empty lists are removed.
        for target_id in list(self._witnesses):
            fresh = [
                witness
                for witness in self._witnesses.get(target_id, [])
                if (now - float(witness.get("timestamp", 0) or 0)) <= witness_ttl
            ]
            if len(fresh) != len(self._witnesses.get(target_id, [])):
                changed = True
                if fresh:
                    self._witnesses[target_id] = fresh
                else:
                    del self._witnesses[target_id]
        if self._prune_stale_mailbox_bindings(now=now):
            changed = True
        return changed

    def _prune_stale_mailbox_bindings(self, *, now: float | None = None) -> bool:
        """Normalize binding entries, drop expired ones and empty agents.

        Returns True when anything changed (callers then persist).
        """
        current = time.time() if now is None else now
        changed = False
        stale_agents: list[str] = []
        for agent_id, kinds in self._mailbox_bindings.items():
            normalized_updates: dict[str, dict[str, Any]] = {}
            # Expired = uncoercible entry, or past its recomputed expiry.
            expired_kinds = [
                k
                for k, v in kinds.items()
                if not self._coerce_mailbox_binding_entry(v, now=current)
                or current > self._mailbox_binding_expires_at(
                    self._coerce_mailbox_binding_entry(v, now=current)
                )
            ]
            for kind, entry in list(kinds.items()):
                normalized = self._coerce_mailbox_binding_entry(entry, now=current)
                if normalized and normalized != entry:
                    normalized_updates[kind] = normalized
            for kind, normalized in normalized_updates.items():
                kinds[kind] = normalized
                changed = True
            for k in expired_kinds:
                del kinds[k]
                changed = True
            if not kinds:
                stale_agents.append(agent_id)
        for agent_id in stale_agents:
            del self._mailbox_bindings[agent_id]
            changed = True
        return changed

    def _metadata_persist_enabled(self) -> bool:
        # Binding metadata is only persisted when explicitly enabled AND
        # acknowledged (double opt-in for the privacy trade-off).
        settings = self._settings()
        return bool(getattr(settings, "MESH_DM_METADATA_PERSIST", False)) and bool(
            getattr(settings, "MESH_DM_METADATA_PERSIST_ACKNOWLEDGE", False)
        )

    def _flush(self) -> None:
        """Actually write to disk (called by timer or atexit)."""
        if not self._dirty:
            return
        try:
            self._prune_stale_metadata()
            relay_file = self._relay_file()
            self._relay_data_dir().mkdir(parents=True, exist_ok=True)
            # Always-persisted metadata; mailboxes and bindings are added below
            # only when their respective persistence toggles allow it.
            payload: dict[str, Any] = {
                "saved_at": int(time.time()),
                "dh_keys": self._dh_keys,
                "prekey_bundles": self._prekey_bundles,
                "prekey_lookup_aliases": self._prekey_lookup_aliases,
                "witnesses": self._witnesses,
                "blocks": {k: sorted(v) for k, v in self._blocks.items()},
                "nonce_caches": {aid: dict(c) for aid, c in self._nonce_caches.items()},
                "stats": self._stats,
            }
            if self._metadata_persist_enabled():
                payload["mailbox_bindings"] = {
                    agent_id: {
                        mailbox_type: {
                            "token_hash": str(entry.get("token_hash", "") or "").strip(),
                            "bound_at": float(entry.get("bound_at", 0) or 0),
                            "last_used": float(entry.get("last_used", 0) or 0),
                            "expires_at": float(entry.get("expires_at", 0) or 0),
                        }
                        for mailbox_type, entry in bindings.items()
                        if isinstance(entry, dict) and str(entry.get("token_hash", "") or "").strip()
                    }
                    for agent_id, bindings in self._mailbox_bindings.items()
                    if isinstance(bindings, dict)
                }
            if self._persist_spool_enabled():
                payload["mailboxes"] = {
                    key: [m.__dict__ for m in msgs] for key, msgs in self._mailboxes.items()
                }
            write_secure_json(relay_file, payload)
            self._dirty = False
            self._last_persist_error = ""
        except Exception as exc:
            self._record_persist_failure("flush", exc)
            if self._persist_failures_are_fatal():
                raise

    def _load(self) -> None:
        # Startup counterpart of _reload_snapshot_from_shared_relay: populates
        # instance state in place from the relay file (no atomic swap needed —
        # nothing else references this instance yet).
        relay_file = self._relay_file()
        if not relay_file.exists():
            return
        try:
            data = read_secure_json(relay_file, lambda: {})
        except Exception:
            return
        if self._persist_spool_enabled():
            mailboxes = data.get("mailboxes", {})
            if isinstance(mailboxes, dict):
                for key, items in mailboxes.items():
                    if not isinstance(items, list):
                        continue
                    restored: list[DMMessage] = []
                    for item in items:
                        try:
                            restored.append(
                                DMMessage(
                                    sender_id=str(item.get("sender_id", "")),
                                    ciphertext=str(item.get("ciphertext", "")),
                                    timestamp=float(item.get("timestamp", 0)),
                                    msg_id=str(item.get("msg_id", "")),
                                    delivery_class=str(item.get("delivery_class", "shared")),
                                    sender_seal=str(item.get("sender_seal", "")),
                                    relay_salt=str(item.get("relay_salt", "") or ""),
                                    sender_block_ref=str(item.get("sender_block_ref", "") or ""),
                                    payload_format=str(item.get("payload_format", item.get("format", "dm1")) or "dm1"),
                                    session_welcome=str(item.get("session_welcome", "") or ""),
                                )
                            )
                        except Exception:
                            # Skip individual malformed records, keep the rest.
                            continue
                    # Backfill block refs for records persisted before the field existed.
                    for message in restored:
                        if not message.sender_block_ref:
                            message.sender_block_ref = self._message_block_ref(message)
                    if restored:
                        self._mailboxes[key] = restored
        dh_keys = data.get("dh_keys", {})
        if isinstance(dh_keys, dict):
            self._dh_keys = {str(k): dict(v) for k, v in dh_keys.items() if isinstance(v, dict)}
        prekey_bundles = data.get("prekey_bundles", {})
        if isinstance(prekey_bundles, dict):
            self._prekey_bundles = {
                str(k): dict(v) for k, v in prekey_bundles.items() if isinstance(v, dict)
            }
        prekey_lookup_aliases = data.get("prekey_lookup_aliases", {})
        if isinstance(prekey_lookup_aliases, dict):
            restored_aliases: dict[str, dict[str, Any]] = {}
            alias_records_migrated = False
            for key, value in prekey_lookup_aliases.items():
                handle = str(key or "").strip()
                record = self._coerce_prekey_lookup_alias_record(value)
                if not handle or not record:
                    continue
                restored_aliases[handle] = record
                if value != record:
                    # Legacy shape was rewritten; flush so disk catches up.
                    alias_records_migrated = True
            self._prekey_lookup_aliases = restored_aliases
            if alias_records_migrated:
                self._dirty = True
        now = time.time()
        mailbox_bindings = data.get("mailbox_bindings", {})
        if isinstance(mailbox_bindings, dict):
            if self._metadata_persist_enabled():
                restored_bindings: dict[str, dict[str, dict[str, Any]]] = {}
                for agent_id, bindings in mailbox_bindings.items():
                    if not isinstance(bindings, dict):
                        continue
                    restored_agent: dict[str, dict[str, Any]] = {}
                    for kind, entry in bindings.items():
                        token_hash = ""
                        last_used = now
                        if isinstance(entry, dict):
                            token_hash = str(entry.get("token_hash", "") or "").strip()
                            last_used = float(entry.get("last_used", now) or now)
                        else:
                            # Legacy format stored the bare token hash string.
                            token_hash = str(entry or "").strip()
                        if not token_hash:
                            continue
                        normalized = self._coerce_mailbox_binding_entry(
                            {
                                "token_hash": token_hash,
                                "bound_at": float(entry.get("bound_at", last_used) or last_used)
                                if isinstance(entry, dict)
                                else last_used,
                                "last_used": last_used,
                                "expires_at": float(entry.get("expires_at", 0) or 0)
                                if isinstance(entry, dict)
                                else 0,
                            },
                            now=now,
                        )
                        if normalized:
                            restored_agent[str(kind)] = normalized
                    if restored_agent:
                        restored_bindings[str(agent_id)] = restored_agent
                self._mailbox_bindings = defaultdict(dict, restored_bindings)
            elif mailbox_bindings:
                # Old relay files may still contain persisted mailbox bindings.
                # When metadata persistence is disabled we intentionally do not
                # restore them, and mark dirty so the next flush rewrites the
                # relay state without that graph metadata.
                self._dirty = True
        # Restore witness attestations (only well-formed lists survive).
        witnesses = data.get("witnesses", {})
        if isinstance(witnesses, dict):
            self._witnesses = defaultdict(
                list,
                {
                    str(k): list(v)
                    for k, v in witnesses.items()
                    if isinstance(v, list)
                },
            )
        # Restore per-recipient block lists, normalizing every entry through
        # the canonical blocked-id form and dropping blank values.
        blocks = data.get("blocks", {})
        if isinstance(blocks, dict):
            for key, values in blocks.items():
                if isinstance(values, list):
                    self._blocks[str(key)] = {
                        self._canonical_blocked_id(str(v))
                        for v in values
                        if str(v or "").strip()
                    }
        # Restore nonce replay windows, discarding already-expired entries.
        nonce_caches = data.get("nonce_caches", {})
        if isinstance(nonce_caches, dict) and nonce_caches:
            for aid, entries in nonce_caches.items():
                if not isinstance(entries, dict):
                    continue
                # Sort oldest-expiry-first so later trimming can pop from
                # the front of the OrderedDict.
                restored = sorted(
                    ((str(k), float(v)) for k, v in entries.items() if float(v) > now),
                    key=lambda item: item[1],
                )
                if restored:
                    self._nonce_caches[str(aid)] = OrderedDict(restored)
        else:
            # Backward compatibility: migrate flat nonce_cache → per-agent
            nonce_cache = data.get("nonce_cache", {})
            if isinstance(nonce_cache, dict):
                for composite_key, expiry in nonce_cache.items():
                    if float(expiry) <= now:
                        continue
                    # Legacy keys were "agent_id:nonce" composites.
                    parts = str(composite_key).split(":", 1)
                    if len(parts) == 2:
                        aid, nonce_val = parts
                        if aid not in self._nonce_caches:
                            self._nonce_caches[aid] = OrderedDict()
                        self._nonce_caches[aid][nonce_val] = float(expiry)
        stats = data.get("stats", {})
        if isinstance(stats, dict):
            self._stats = {str(k): int(v) for k, v in stats.items() if isinstance(v, (int, float))}
        # Pending ciphertext is never restored from disk, so the in-memory
        # message count is always recomputed rather than trusted.
        self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
        if self._prune_stale_metadata():
            self._dirty = True

    def _bundle_fingerprint(
        self,
        *,
        dh_pub_key: str,
        dh_algo: str,
        public_key: str,
        public_key_algo: str,
        protocol_version: str,
    ) -> str:
        """Return a stable SHA-256 fingerprint over the bundle's key material."""
        material = "|".join(
            [
                dh_pub_key,
                dh_algo,
                public_key,
                public_key_algo,
                protocol_version,
            ]
        )
        return hashlib.sha256(material.encode("utf-8")).hexdigest()

    def _advance_prekey_transparency(
        self,
        *,
        agent_id: str,
        bundle: dict[str, Any],
        signature: str,
        public_key: str,
        public_key_algo: str,
        protocol_version: str,
        sequence: int,
        existing: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Extend the hash-chained transparency log for a prekey publication.

        Each accepted publication appends a new head hashed over the
        publication fingerprint and the previous head, making rollbacks and
        forks detectable. Returns the transparency fields to merge into the
        stored bundle record.
        """
        previous_head = str((existing or {}).get("prekey_transparency_head", "") or "").strip().lower()
        previous_size = int((existing or {}).get("prekey_transparency_size", 0) or 0)
        publication_fingerprint = transparency_fingerprint_for_bundle_record(
            {
                "agent_id": agent_id,
                "bundle": bundle,
                "signature": signature,
                "public_key": public_key,
                "public_key_algo": public_key_algo,
                "protocol_version": protocol_version,
                "sequence": int(sequence),
            }
        )
        next_size = previous_size + 1
        head_payload = {
            "agent_id": agent_id,
            "sequence": int(sequence),
            "signed_at": int(bundle.get("signed_at", 0) or 0),
            "publication_fingerprint": publication_fingerprint,
            "previous_head": previous_head,
            "index": next_size,
        }
        # Canonical (sorted, compact) JSON keeps the head deterministic.
        head = hashlib.sha256(_stable_json(head_payload).encode("utf-8")).hexdigest()
        history = list((existing or {}).get("prekey_transparency_log") or [])
        history.append(
            {
                "index": next_size,
                "sequence": int(sequence),
                "signed_at": int(bundle.get("signed_at", 0) or 0),
                "publication_fingerprint": publication_fingerprint,
                "previous_head": previous_head,
                "head": head,
                "observed_at": int(time.time()),
            }
        )
        return {
            "prekey_transparency_head": head,
            "prekey_transparency_size": next_size,
            "prekey_transparency_fingerprint": publication_fingerprint,
            # Only the 16 most recent log entries are retained.
            "prekey_transparency_log": history[-16:],
        }

    def register_dh_key(
        self,
        agent_id: str,
        dh_pub_key: str,
        dh_algo: str,
        timestamp: int,
        signature: str,
        public_key: str,
        public_key_algo: str,
        protocol_version: str,
        sequence: int,
    ) -> tuple[bool, str, dict[str, Any] | None]:
        """Register/update an agent's DH public key bundle with replay protection."""
        fingerprint = self._bundle_fingerprint(
            dh_pub_key=dh_pub_key,
            dh_algo=dh_algo,
            public_key=public_key,
            public_key_algo=public_key_algo,
            protocol_version=protocol_version,
        )
        with self._lock:
            self._refresh_from_shared_relay()
            existing = self._dh_keys.get(agent_id)
            if existing:
                # Monotonic sequence + timestamp checks defeat rollback.
                existing_seq = int(existing.get("sequence", 0) or 0)
                existing_ts = int(existing.get("timestamp", 0) or
0)
                if sequence <= existing_seq:
                    metrics_inc("dm_key_replay")
                    return False, "DM key replay or rollback rejected", None
                if timestamp < existing_ts:
                    metrics_inc("dm_key_stale")
                    return False, "DM key timestamp is older than the current bundle", None
            self._dh_keys[agent_id] = {
                "dh_pub_key": dh_pub_key,
                "dh_algo": dh_algo,
                "timestamp": timestamp,
                "signature": signature,
                "public_key": public_key,
                "public_key_algo": public_key_algo,
                "protocol_version": protocol_version,
                "sequence": sequence,
                "bundle_fingerprint": fingerprint,
            }
            self._save()
            return True, "ok", {
                "accepted_sequence": sequence,
                "bundle_fingerprint": fingerprint,
            }

    def get_dh_key(self, agent_id: str) -> dict[str, Any] | None:
        """Return the stored DH key record for `agent_id`, or None."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._prune_stale_metadata()
            return self._dh_keys.get(agent_id)

    def get_dh_key_by_lookup(self, lookup_token: str) -> tuple[dict[str, Any] | None, str]:
        """Resolve a prekey lookup alias and return the DH key for the resolved agent."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._prune_stale_metadata()
            resolved_id = self._resolve_prekey_lookup_alias(lookup_token)
            if not resolved_id:
                return None, ""
            stored = self._dh_keys.get(resolved_id)
            if not stored:
                return None, ""
            # Copy so callers cannot mutate relay-held state.
            return dict(stored), resolved_id

    def register_prekey_bundle(
        self,
        agent_id: str,
        bundle: dict[str, Any],
        signature: str,
        public_key: str,
        public_key_algo: str,
        protocol_version: str,
        sequence: int,
        lookup_aliases: list[Any] | None = None,
    ) -> tuple[bool, str, dict[str, Any] | None]:
        """Validate and store a signed prekey bundle, advancing its transparency log."""
        ok, reason = _validate_bundle_record(
            {
                "bundle": bundle,
                "public_key": public_key,
                "agent_id": agent_id,
            }
        )
        if not ok:
            return False, reason, None
        with self._lock:
            self._refresh_from_shared_relay()
            existing = self._prekey_bundles.get(agent_id)
            if existing:
                # Sequence must strictly increase: reject replays/rollbacks.
                existing_seq = int(existing.get("sequence", 0) or 0)
                if sequence <= existing_seq:
                    return False, "Prekey bundle replay or rollback rejected", None
            transparency = self._advance_prekey_transparency(
                agent_id=agent_id,
                bundle=dict(bundle or {}),
                signature=signature,
                public_key=public_key,
                public_key_algo=public_key_algo,
                protocol_version=protocol_version,
                sequence=int(sequence),
                existing=existing,
            )
            stored = {
                "bundle": dict(bundle or {}),
                "signature": signature,
                "public_key": public_key,
                "public_key_algo": public_key_algo,
                "protocol_version": protocol_version,
                "sequence": int(sequence),
                "updated_at": int(time.time()),
                **transparency,
            }
            self._prekey_bundles[agent_id] = stored
            if lookup_aliases:
                alias_updated_at = float(stored.get("updated_at", time.time()) or time.time())
                # At most 16 aliases are honored per publication.
                for alias in lookup_aliases[:16]:
                    # NOTE(review): assumes _coerce_prekey_lookup_alias_record
                    # always returns a dict here; a falsy result would make the
                    # .get() calls below raise — confirm the coercion contract.
                    alias_record = self._coerce_prekey_lookup_alias_record(
                        {
                            "agent_id": agent_id,
                            **dict(alias),
                        }
                        if isinstance(alias, dict)
                        else self._make_prekey_lookup_alias_record(agent_id, updated_at=alias_updated_at)
                    )
                    # Aliases may be bare strings or dicts carrying a "handle".
                    handle = str(alias.get("handle", "") if isinstance(alias, dict) else alias or "").strip()
                    if handle:
                        self._prekey_lookup_aliases[handle] = self._make_prekey_lookup_alias_record(
                            agent_id,
                            updated_at=alias_updated_at,
                            expires_at=int(alias_record.get("expires_at", 0) or 0),
                            max_uses=int(alias_record.get("max_uses", 0) or 0),
                            use_count=int(alias_record.get("use_count", 0) or 0),
                            last_used_at=float(alias_record.get("last_used_at", 0) or 0),
                        )
            self._save()
            return True, "ok", {"accepted_sequence": int(sequence), **transparency}

    def get_prekey_bundle(self, agent_id: str) -> dict[str, Any] | None:
        """Return a copy of the stored prekey bundle record for `agent_id`, or None."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._prune_stale_metadata()
            stored = self._prekey_bundles.get(agent_id)
            if not stored:
                return None
            return dict(stored)

    def get_prekey_bundle_by_lookup(self, lookup_token: str) -> tuple[dict[str, Any] | None, str]:
        """Resolve a lookup alias to a prekey bundle.

        Returns (bundle, agent_id); (None, "") when the alias does not resolve
        or the resolved agent has no stored bundle.
        """
        with self._lock:
            self._refresh_from_shared_relay()
            self._prune_stale_metadata()
            resolved_id = self._resolve_prekey_lookup_alias(lookup_token)
            if not resolved_id:
                return None, ""
            stored = self._prekey_bundles.get(resolved_id)
            if not stored:
                return None, ""
            return dict(stored), resolved_id

    def register_prekey_lookup_alias(
        self,
        alias: str,
        agent_id: str,
        *,
        expires_at: int = 0,
        max_uses: int = 0,
        use_count: int = 0,
        last_used_at: int = 0,
    ) -> None:
        """Register a lookup alias for an agent's prekey bundle."""
        handle = str(alias or "").strip()
        aid = str(agent_id or "").strip()
        # Silently ignore blank alias or agent id.
        if handle and aid:
            with self._lock:
                self._refresh_from_shared_relay()
                self._prekey_lookup_aliases[handle] = self._make_prekey_lookup_alias_record(
                    aid,
                    expires_at=expires_at,
                    max_uses=max_uses,
                    use_count=use_count,
                    last_used_at=last_used_at,
                )
                self._save()

    def consume_one_time_prekey(self, agent_id: str) -> dict[str, Any] | None:
        """Atomically claim the next published one-time prekey for a peer bundle."""
        claimed: dict[str, Any] | None = None
        with self._lock:
            self._refresh_from_shared_relay()
            stored = self._prekey_bundles.get(agent_id)
            if not stored:
                return None
            bundle = dict(stored.get("bundle") or {})
            otks = list(bundle.get("one_time_prekeys") or [])
            if not otks:
                # Nothing left to claim: hand back the bundle unchanged.
                return dict(stored)
            # Claim FIFO: pop the oldest published one-time prekey.
            claimed = dict(otks.pop(0) or {})
            bundle["one_time_prekeys"] = otks
            bundle["one_time_prekey_count"] = len(otks)
            stored = dict(stored)
            stored["bundle"] = bundle
            stored["updated_at"] = int(time.time())
            self._prekey_bundles[agent_id] = stored
            self._save()
        result = dict(stored)
        result["claimed_one_time_prekey"] = claimed
        return result

    def _prune_witnesses(self, target_id: str) -> None:
        """Drop witness entries older than the witness TTL; delete empty buckets."""
        cutoff = time.time() - self._witness_ttl_seconds()
        self._witnesses[target_id] = [
            w for w in self._witnesses.get(target_id, [])
            if float(w.get("timestamp", 0)) >= cutoff
        ]
        if not self._witnesses[target_id]:
            del self._witnesses[target_id]

    def record_witness(
        self,
        witness_id: str,
        target_id: str,
        dh_pub_key:
        str,
        timestamp: int,
    ) -> tuple[bool, str]:
        """Append a witness attestation of `target_id`'s DH public key."""
        if not witness_id or not target_id or not dh_pub_key:
            return False, "Missing witness_id, target_id, or dh_pub_key"
        if witness_id == target_id:
            return False, "Cannot witness yourself"
        with self._lock:
            self._refresh_from_shared_relay()
            self._prune_witnesses(target_id)
            entries = self._witnesses.get(target_id, [])
            # Reject duplicate (witness, key) pairs.
            for entry in entries:
                if entry.get("witness_id") == witness_id and entry.get("dh_pub_key") == dh_pub_key:
                    return False, "Duplicate witness"
            entries.append(
                {
                    "witness_id": witness_id,
                    "dh_pub_key": dh_pub_key,
                    "timestamp": int(timestamp),
                }
            )
            # Keep only the 50 most recent entries per target.
            self._witnesses[target_id] = entries[-50:]
            self._save()
            return True, "ok"

    def get_witnesses(self, target_id: str, dh_pub_key: str | None = None, limit: int = 5) -> list[dict]:
        """Return up to `limit` most recent witness entries for `target_id`.

        When `dh_pub_key` is given, only entries for that key are returned.
        """
        with self._lock:
            self._refresh_from_shared_relay()
            self._prune_witnesses(target_id)
            entries = list(self._witnesses.get(target_id, []))
            if dh_pub_key:
                entries = [e for e in entries if e.get("dh_pub_key") == dh_pub_key]
            entries = sorted(entries, key=lambda e: e.get("timestamp", 0), reverse=True)
            return entries[: max(1, limit)]

    def _epoch_bucket(self, ts: float | None = None) -> int:
        """Return the coarse epoch bucket (EPOCH_SECONDS windows) for `ts` or now."""
        now = ts if ts is not None else time.time()
        return int(now // EPOCH_SECONDS)

    def _mailbox_limit_for_class(self, delivery_class: str) -> int:
        # Per-class mailbox capacity; unrecognized classes fall back to the
        # self-mailbox limit.
        if delivery_class == "request":
            return self._request_mailbox_limit()
        if delivery_class == "shared":
            return self._shared_mailbox_limit()
        return self._self_mailbox_limit()

    def _cleanup_expired(self) -> bool:
        """Drop messages older than TTL_SECONDS; return True if anything changed."""
        now = time.time()
        changed = False
        for mailbox_id in list(self._mailboxes):
            fresh = [m for m in self._mailboxes[mailbox_id] if now - m.timestamp < TTL_SECONDS]
            if len(fresh) != len(self._mailboxes[mailbox_id]):
                changed = True
                self._mailboxes[mailbox_id] = fresh
            if not self._mailboxes[mailbox_id]:
                del self._mailboxes[mailbox_id]
                changed = True
        self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
        return changed

    def _total_nonce_count(self) -> int:
        # Total nonce entries across every per-agent replay cache.
        return sum(len(c) for c in self._nonce_caches.values())

    def _trim_global_nonce_budget(self, *, preferred_agent_id: str = "") -> int:
        """Evict oldest-expiring nonces until under the global cap.

        Prefers evicting from agents other than `preferred_agent_id` so one
        busy agent cannot starve the rest; falls back to the preferred agent
        when it is the only one holding entries. Returns the eviction count.
        """
        trimmed = 0
        max_entries = self._nonce_cache_max_entries()
        preferred_agent_id = str(preferred_agent_id or "").strip()
        while self._total_nonce_count() >= max_entries:
            oldest_choice: tuple[str, str, float] | None = None
            for aid, cache in self._nonce_caches.items():
                if not cache:
                    continue
                if preferred_agent_id and aid == preferred_agent_id and len(self._nonce_caches) > 1:
                    continue
                # Caches are maintained oldest-expiry-first, so the first
                # item is that agent's cheapest eviction candidate.
                nonce_value, expiry = next(iter(cache.items()))
                if oldest_choice is None or float(expiry) < oldest_choice[2]:
                    oldest_choice = (aid, nonce_value, float(expiry))
            if oldest_choice is None and preferred_agent_id:
                # Only the preferred agent holds entries: evict from it anyway.
                for aid, cache in self._nonce_caches.items():
                    if not cache:
                        continue
                    nonce_value, expiry = next(iter(cache.items()))
                    if oldest_choice is None or float(expiry) < oldest_choice[2]:
                        oldest_choice = (aid, nonce_value, float(expiry))
            if oldest_choice is None:
                break
            aid, nonce_value, _expiry = oldest_choice
            cache = self._nonce_caches.get(aid)
            if not cache:
                continue
            cache.pop(nonce_value, None)
            if not cache:
                self._nonce_caches.pop(aid, None)
            trimmed += 1
        if trimmed:
            metrics_inc("dm_nonce_cache_trimmed")
        return trimmed

    def consume_nonce(self, agent_id: str, nonce: str, timestamp: int) -> tuple[bool, str]:
        """Accept `nonce` at most once per agent within its replay window."""
        nonce = str(nonce or "").strip()
        if not nonce:
            return False, "Missing nonce"
        agent_id = str(agent_id or "").strip()
        now = time.time()
        with self._lock:
            self._refresh_from_shared_relay()
            # Expire stale entries across all agents
            for aid in list(self._nonce_caches):
                cache = self._nonce_caches[aid]
                self._nonce_caches[aid] = OrderedDict(
                    (k, exp) for k, exp in cache.items() if float(exp) > now
                )
                if not self._nonce_caches[aid]:
                    del self._nonce_caches[aid]
            agent_cache = self._nonce_caches.get(agent_id)
            # Replay check
            if agent_cache and nonce in agent_cache:
                metrics_inc("dm_nonce_replay")
                return False, "nonce replay detected"
            # Per-agent capacity check
            per_agent_max = \
self._nonce_per_agent_max()
            if agent_cache and len(agent_cache) >= per_agent_max:
                metrics_inc("dm_nonce_cache_full")
                return False, "nonce cache at capacity"
            # Global capacity is a soft memory bound. Trim the oldest nonce
            # entries first so one busy agent cannot turn the global budget
            # into a cross-agent availability choke point.
            if self._total_nonce_count() >= self._nonce_cache_max_entries():
                self._trim_global_nonce_budget(preferred_agent_id=agent_id)
            # Expiry honors both the server clock and the claimed timestamp.
            expiry = max(now + self._nonce_ttl_seconds(), float(timestamp) + self._nonce_ttl_seconds())
            if agent_cache is None:
                agent_cache = OrderedDict()
                self._nonce_caches[agent_id] = agent_cache
            agent_cache[nonce] = expiry
            # Keep the cache ordered so the newest expiry sits at the end.
            agent_cache.move_to_end(nonce)
            self._save()
            return True, "ok"

    def deposit(
        self,
        *,
        sender_id: str,
        raw_sender_id: str = "",
        recipient_id: str = "",
        ciphertext: str,
        msg_id: str = "",
        delivery_class: str,
        recipient_token: str | None = None,
        sender_seal: str = "",
        relay_salt: str = "",
        sender_token_hash: str = "",
        payload_format: str = "dm1",
        session_welcome: str = "",
    ) -> dict[str, Any]:
        """Deposit an encrypted message into a request or shared mailbox.

        Enforces block lists, size and capacity limits, and de-duplicates by
        msg_id. The ciphertext is stored opaquely and never inspected.
        NOTE(review): `relay_salt` is accepted but not forwarded into the
        stored DMMessage (which has a relay_salt field) — confirm intent.
        """
        with self._lock:
            self._refresh_from_shared_relay()
            # Block-list checks key off the raw sender identity when given.
            authority_sender = str(raw_sender_id or sender_id or "").strip()
            sender_block_ref = self._sender_block_ref(
                authority_sender,
                scope=self._sender_block_scope(
                    recipient_id=recipient_id,
                    recipient_token=str(recipient_token or ""),
                    delivery_class=delivery_class,
                ),
            )
            blocked_refs = self._sender_block_refs(
                authority_sender,
                recipient_id=recipient_id,
                recipient_token=str(recipient_token or ""),
                delivery_class=delivery_class,
            )
            if recipient_id and any(ref in self._blocks.get(recipient_id, set()) for ref in blocked_refs):
                metrics_inc("dm_drop_blocked")
                return {"ok": False, "detail": "Recipient is not accepting your messages"}
            if len(ciphertext) > int(self._settings().MESH_DM_MAX_MSG_BYTES):
                metrics_inc("dm_drop_oversize")
                return {
                    "ok": False,
                    "detail": f"Message too large ({len(ciphertext)} > {int(self._settings().MESH_DM_MAX_MSG_BYTES)})",
                }
            self._cleanup_expired()
            if delivery_class == "request":
                if not sender_token_hash:
                    return {"ok": False, "detail": "sender_token required for request delivery"}
                mailbox_key = self._mailbox_key("requests", recipient_id)
            elif delivery_class == "shared":
                if not recipient_token:
                    metrics_inc("dm_claim_invalid")
                    return {"ok": False, "detail": "recipient_token required for shared delivery"}
                mailbox_key = self._hashed_mailbox_token(recipient_token)
            else:
                return {"ok": False, "detail": "Unsupported delivery_class"}
            if len(self._mailboxes[mailbox_key]) >= self._mailbox_limit_for_class(delivery_class):
                metrics_inc("dm_drop_full")
                return {"ok": False, "detail": "Recipient mailbox full"}
            if not msg_id:
                msg_id = f"dm_{int(time.time() * 1000)}_{secrets.token_hex(6)}"
            elif any(m.msg_id == msg_id for m in self._mailboxes[mailbox_key]):
                # Idempotent redelivery: the message is already queued.
                return {"ok": True, "msg_id": msg_id}
            # Token-bearing senders are stored pseudonymously.
            relay_sender_id = (
                f"sender_token:{sender_token_hash}" if sender_token_hash else sender_id
            )
            self._mailboxes[mailbox_key].append(
                DMMessage(
                    sender_id=relay_sender_id,
                    ciphertext=ciphertext,
                    timestamp=time.time(),
                    msg_id=msg_id,
                    delivery_class=delivery_class,
                    sender_seal=sender_seal,
                    sender_block_ref=sender_block_ref,
                    payload_format=str(payload_format or "dm1"),
                    session_welcome=str(session_welcome or ""),
                )
            )
            self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
            self._save()
            return {"ok": True, "msg_id": msg_id}

    def is_blocked(self, recipient_id: str, sender_id: str) -> bool:
        """Return True if `recipient_id` blocks request-class mail from `sender_id`."""
        with self._lock:
            self._refresh_from_shared_relay()
            if not recipient_id:
                return False
            blocked_refs = self._sender_block_refs(
                sender_id,
                recipient_id=recipient_id,
                delivery_class="request",
            )
            return any(ref in self._blocks.get(recipient_id, set()) for ref in blocked_refs)

    def _collect_from_keys(
        self,
        keys: list[str],
        *,
        destructive: bool,
        limit: int = 0,
    ) -> tuple[list[dict[str, Any]], bool]:
        """Gather messages from `keys`, de-duplicated by msg_id, time-sorted.

        With destructive=True the mailboxes are drained; messages beyond
        `limit` are put back. Returns (messages, has_more).
        """
        messages: list[DMMessage] = []
        seen: set[str] = set()
        popped: dict[str, list[DMMessage]] = {}
        for key in keys:
            if destructive:
                raw = \
self._mailboxes.pop(key, [])
                popped[key] = raw
            else:
                raw = list(self._mailboxes.get(key, []))
            for message in raw:
                if message.msg_id in seen:
                    continue
                seen.add(message.msg_id)
                messages.append(message)
        sorted_messages = sorted(messages, key=lambda item: item.timestamp)
        has_more = False
        if limit > 0 and len(sorted_messages) > limit:
            has_more = True
            kept = sorted_messages[:limit]
            if destructive:
                # Re-queue everything beyond `limit` so a later collect can
                # still deliver it.
                kept_ids = {m.msg_id for m in kept}
                for key, original in popped.items():
                    remaining = [m for m in original if m.msg_id not in kept_ids]
                    if remaining:
                        self._mailboxes.setdefault(key, []).extend(remaining)
            sorted_messages = kept
        if destructive:
            self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
            self._save()
        result = [
            {
                "sender_id": message.sender_id,
                "ciphertext": message.ciphertext,
                "timestamp": message.timestamp,
                "msg_id": message.msg_id,
                "delivery_class": message.delivery_class,
                "sender_seal": message.sender_seal,
                "format": message.payload_format,
                "session_welcome": message.session_welcome,
            }
            for message in sorted_messages
        ]
        return result, has_more

    def collect_claims(
        self,
        agent_id: str,
        claims: list[dict[str, Any]],
        *,
        limit: int = 0,
    ) -> tuple[list[dict[str, Any]], bool]:
        """Destructively collect messages for up to 32 mailbox claims."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._cleanup_expired()
            keys: list[str] = []
            for claim in claims[:32]:
                keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
            # dict.fromkeys de-duplicates while preserving claim order.
            return self._collect_from_keys(list(dict.fromkeys(keys)), destructive=True, limit=limit)

    def count_claims(self, agent_id: str, claims: list[dict[str, Any]]) -> int:
        """Count pending messages for the given claims without draining them."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._cleanup_expired()
            keys: list[str] = []
            for claim in claims[:32]:
                keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
            messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
            return len(messages)

    def claim_message_ids(self, agent_id: str, claims: list[dict[str, Any]]) -> set[str]:
        """Return the non-empty msg_ids currently claimable, without draining."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._cleanup_expired()
            keys: list[str] = []
            for claim in claims[:32]:
                keys.extend(self._mailbox_keys_for_claim(agent_id, claim))
            messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
            return {
                str(message.get("msg_id", "") or "")
                for message in messages
                if str(message.get("msg_id", "") or "")
            }

    def collect_legacy(
        self,
        agent_id: str | None = None,
        agent_token: str | None = None,
        *,
        limit: int = 0,
    ) -> tuple[list[dict[str, Any]], bool]:
        """Destructively collect from legacy token-addressed mailboxes."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._cleanup_expired()
            if not agent_token:
                return [], False
            # Check both the peppered and the raw token key for back-compat.
            keys = [self._pepper_token(agent_token), agent_token]
            return self._collect_from_keys(list(dict.fromkeys(keys)), destructive=True, limit=limit)

    def count_legacy(self, agent_id: str | None = None, agent_token: str | None = None) -> int:
        """Count pending legacy-mailbox messages without draining them."""
        with self._lock:
            self._refresh_from_shared_relay()
            self._cleanup_expired()
            if not agent_token:
                return 0
            keys = [self._pepper_token(agent_token), agent_token]
            messages, _ = self._collect_from_keys(list(dict.fromkeys(keys)), destructive=False)
            return len(messages)

    def block(self, agent_id: str, blocked_id: str) -> None:
        """Block `blocked_id` for `agent_id` and purge their queued messages."""
        with self._lock:
            self._refresh_from_shared_relay()
            blocked_ref = self._canonical_blocked_id(
                blocked_id,
                scope=self._sender_block_scope(recipient_id=agent_id, delivery_class="request"),
            )
            if not blocked_ref:
                return
            self._blocks[agent_id].add(blocked_ref)
            blocked_refs = {blocked_ref}
            blocked_label = str(blocked_id or "").strip()
            if blocked_label and not blocked_label.startswith("ref:"):
                # Also match messages recorded under the legacy ref format.
                blocked_refs.add(self._legacy_sender_block_ref(blocked_label))
            purge_keys = self._legacy_token_candidates(agent_id)
            bound_request = self._bound_mailbox_key(agent_id, "requests")
            bound_self = self._bound_mailbox_key(agent_id, "self")
            if bound_request:
                purge_keys.append(bound_request)
            if bound_self:
                purge_keys.append(bound_self)
            # Cover the current and previous epoch buckets for both kinds.
            purge_keys.extend(
                [
                    self._mailbox_key("self", agent_id),
                    self._mailbox_key("requests", agent_id),
                    self._mailbox_key("self", agent_id, self._epoch_bucket() - 1),
                    self._mailbox_key("requests", agent_id, self._epoch_bucket() - 1),
                ]
            )
            for key in set(purge_keys):
                if key in self._mailboxes:
                    self._mailboxes[key] = [
                        m for m in self._mailboxes[key]
                        if self._message_block_ref(m) not in blocked_refs
                    ]
            self._stats["messages_in_memory"] = sum(len(v) for v in self._mailboxes.values())
            self._save()

    def unblock(self, agent_id: str, blocked_id: str) -> None:
        """Remove `blocked_id` (and its legacy ref form) from `agent_id`'s block list."""
        with self._lock:
            self._refresh_from_shared_relay()
            blocked_ref = self._canonical_blocked_id(
                blocked_id,
                scope=self._sender_block_scope(recipient_id=agent_id, delivery_class="request"),
            )
            if not blocked_ref:
                return
            self._blocks[agent_id].discard(blocked_ref)
            blocked_label = str(blocked_id or "").strip()
            if blocked_label and not blocked_label.startswith("ref:"):
                self._blocks[agent_id].discard(self._legacy_sender_block_ref(blocked_label))
            self._save()


# Module-level singleton used by the rest of the service.
dm_relay = DMRelay()