Add dependency-aware module ordering

This commit is contained in:
Janik Besendorf
2026-06-05 19:46:38 +02:00
parent 11d06a3a16
commit 4530d496d3
4 changed files with 221 additions and 12 deletions
+18
View File
@@ -21,6 +21,24 @@ make check
Run these tests before making new commits or opening pull requests.
## Module dependencies
Modules can require other modules to run first by declaring their classes in
`dependencies`. The command runner uses a stable topological ordering, so the
existing module list order is preserved wherever dependency constraints allow.
```python
class DependentModule(MVTModule):
dependencies = (PrerequisiteModule,)
def run(self):
prerequisite_results = self.get_dependency_results(PrerequisiteModule)
```
Selecting a single module also runs its transitive dependencies. If a dependency
is unavailable or the dependency graph contains a cycle, the command logs a
warning and does not run any modules.
## Profiling
Some MVT modules extract and process significant amounts of data during the analysis process or while checking results against known indicators. Care must be
+82 -11
View File
@@ -8,6 +8,7 @@ import logging
import os
import sys
from datetime import datetime
from heapq import heappop, heappush
from typing import Any, Optional
from rich.console import Console
@@ -18,6 +19,7 @@ from .alerts import AlertLevel, AlertStore
from .config import settings
from .indicators import Indicators
from .module import EncryptedBackupError, MVTModule, run_module, save_timeline
from .module_types import ModuleTimeline
from .utils import (
CustomJSONEncoder,
convert_datetime_to_iso,
@@ -44,7 +46,7 @@ class Command:
disable_indicator_check: bool = False,
) -> None:
self.name = ""
self.modules: list[Any] = []
self.modules: list[type[MVTModule]] = []
self.target_path = target_path
self.results_path = results_path
@@ -63,10 +65,10 @@ class Command:
# This list will contain all executed modules.
# We can use this to reference e.g. self.executed[0].results.
self.executed: list[Any] = []
self.executed: list[MVTModule] = []
self.hashes = hashes
self.hash_values: list[dict[str, Any]] = []
self.timeline: list[dict[str, Any]] = []
self.timeline: ModuleTimeline = []
# Load IOCs
self._create_storage()
@@ -258,20 +260,84 @@ class Command:
console.print("")
console.print(panel)
def _ordered_modules(self) -> Optional[list[type[MVTModule]]]:
"""Return enabled modules in stable topological order."""
module_indexes = {module: index for index, module in enumerate(self.modules)}
if self.module_name:
selected = [
module for module in self.modules if module.__name__ == self.module_name
]
else:
selected = [module for module in self.modules if module.enabled]
required = set(selected)
pending = list(selected)
while pending:
module = pending.pop()
for dependency in module.dependencies:
if dependency not in module_indexes:
self.log.warning(
"Module %s depends on unavailable module %s. "
"No modules will be run.",
module.__name__,
dependency.__name__,
)
return None
if dependency not in required:
required.add(dependency)
pending.append(dependency)
dependents: dict[type[MVTModule], list[type[MVTModule]]] = {
module: [] for module in required
}
indegree = {module: 0 for module in required}
for module in required:
for dependency in module.dependencies:
if dependency not in required:
continue
dependents[dependency].append(module)
indegree[module] += 1
ready: list[tuple[int, type[MVTModule]]] = []
for module, count in indegree.items():
if count == 0:
heappush(ready, (module_indexes[module], module))
ordered = []
while ready:
_, module = heappop(ready)
ordered.append(module)
for dependent in dependents[module]:
indegree[dependent] -= 1
if indegree[dependent] == 0:
heappush(ready, (module_indexes[dependent], dependent))
if len(ordered) != len(required):
cyclic_modules = sorted(
(module.__name__ for module, count in indegree.items() if count > 0)
)
self.log.warning(
"Circular module dependency detected involving: %s. "
"No modules will be run.",
", ".join(cyclic_modules),
)
return None
return ordered
def run(self) -> None:
ordered_modules = self._ordered_modules()
if ordered_modules is None:
return
try:
self.init()
except NotImplementedError:
pass
for module in self.modules:
if self.module_name and module.__name__ != self.module_name:
continue
if not module.enabled and not (
self.module_name and module.__name__ == self.module_name
):
continue
executed_by_type: dict[type[MVTModule], MVTModule] = {}
for module in ordered_modules:
# FIXME: do we need the logger here
module_logger = logging.getLogger(module.__module__)
@@ -282,6 +348,10 @@ class Command:
module_options=self.module_options,
log=module_logger,
)
m.dependency_modules = {
dependency: executed_by_type[dependency]
for dependency in module.dependencies
}
if self.iocs.total_ioc_count:
m.indicators = self.iocs
@@ -305,6 +375,7 @@ class Command:
return
self.executed.append(m)
executed_by_type[module] = m
self.timeline.extend(m.timeline)
self.alertstore.extend(m.alertstore.alerts)
+10 -1
View File
@@ -9,7 +9,7 @@ import logging
import os
import re
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence
from .alerts import AlertStore
from .indicators import Indicators
@@ -43,6 +43,7 @@ class MVTModule:
enabled: bool = True
slug: Optional[str] = None
dependencies: Sequence[type["MVTModule"]] = ()
def __init__(
self,
@@ -71,6 +72,7 @@ class MVTModule:
self.file_path: Optional[str] = file_path
self.target_path: Optional[str] = target_path
self.results_path: Optional[str] = results_path
self.serial: Optional[str] = None
self.module_options: Optional[Dict[str, Any]] = (
module_options if module_options else {}
)
@@ -81,6 +83,13 @@ class MVTModule:
self.results: ModuleResults = results if results else []
self.timeline: ModuleTimeline = []
self.dependency_modules: Dict[type["MVTModule"], "MVTModule"] = {}
def get_dependency_results(
self, module_class: type["MVTModule"]
) -> ModuleResults:
"""Return the results produced by a prerequisite module."""
return self.dependency_modules[module_class].results
@classmethod
def from_json(cls, json_path: str, log: logging.Logger):
+111
View File
@@ -4,11 +4,59 @@
# https://license.mvt.re/1.1/
import json
import logging
from mvt.common.command import Command
from mvt.common.module import MVTModule
class RecordingModule(MVTModule):
run_order = []
def run(self):
self.run_order.append(self.__class__.__name__)
def check_indicators(self):
pass
class FirstModule(RecordingModule):
def run(self):
super().run()
self.results = ["first"]
class SecondModule(RecordingModule):
dependencies = (FirstModule,)
def run(self):
super().run()
self.results = self.get_dependency_results(FirstModule) + ["second"]
class ThirdModule(RecordingModule):
dependencies = (SecondModule,)
class IndependentModule(RecordingModule):
pass
class RecordingCommand(Command):
def init(self):
self.initialized = True
def module_init(self, module):
pass
def finish(self):
pass
class TestCommand:
def setup_method(self):
RecordingModule.run_order = []
def test_store_alerts_handles_bytes(self, tmp_path):
cmd = Command(results_path=str(tmp_path))
cmd.alertstore.medium(
@@ -21,3 +69,66 @@ class TestCommand:
alerts = json.loads((tmp_path / "alerts.json").read_text())
assert alerts[0]["event"]["payload"] == "\\xa8\\xa9"
def test_modules_run_in_stable_topological_order(self):
cmd = RecordingCommand()
cmd.modules = [ThirdModule, IndependentModule, SecondModule, FirstModule]
cmd.run()
assert RecordingModule.run_order == [
"IndependentModule",
"FirstModule",
"SecondModule",
"ThirdModule",
]
second = next(module for module in cmd.executed if isinstance(module, SecondModule))
assert second.results == ["first", "second"]
def test_selected_module_runs_transitive_dependencies(self):
cmd = RecordingCommand(module_name="ThirdModule")
cmd.modules = [ThirdModule, SecondModule, FirstModule, IndependentModule]
cmd.run()
assert RecordingModule.run_order == [
"FirstModule",
"SecondModule",
"ThirdModule",
]
def test_circular_dependency_warns_and_stops(self, caplog):
class CircularOne(RecordingModule):
pass
class CircularTwo(RecordingModule):
dependencies = (CircularOne,)
CircularOne.dependencies = (CircularTwo,)
cmd = RecordingCommand()
cmd.modules = [CircularOne, CircularTwo]
with caplog.at_level(logging.WARNING):
cmd.run()
assert RecordingModule.run_order == []
assert not hasattr(cmd, "initialized")
assert "Circular module dependency detected" in caplog.text
def test_unavailable_dependency_warns_and_stops(self, caplog):
class UnavailableModule(RecordingModule):
pass
class DependentModule(RecordingModule):
dependencies = (UnavailableModule,)
cmd = RecordingCommand()
cmd.modules = [DependentModule]
with caplog.at_level(logging.WARNING):
cmd.run()
assert RecordingModule.run_order == []
assert not hasattr(cmd, "initialized")
assert "depends on unavailable module UnavailableModule" in caplog.text