diff --git a/docs/development.md b/docs/development.md index 6819750..6a15114 100644 --- a/docs/development.md +++ b/docs/development.md @@ -21,6 +21,24 @@ make check Run these tests before making new commits or opening pull requests. +## Module dependencies + +Modules can require other modules to run first by declaring their classes in +`dependencies`. The command runner uses a stable topological ordering, so the +existing module list order is preserved wherever dependency constraints allow. + +```python +class DependentModule(MVTModule): + dependencies = (PrerequisiteModule,) + + def run(self): + prerequisite_results = self.get_dependency_results(PrerequisiteModule) +``` + +Selecting a single module also runs its transitive dependencies. If a dependency +is unavailable or the dependency graph contains a cycle, the command logs a +warning and does not run any modules. + ## Profiling Some MVT modules extract and process significant amounts of data during the analysis process or while checking results against known indicators. Care must be diff --git a/src/mvt/common/command.py b/src/mvt/common/command.py index 03d2b00..84b2148 100644 --- a/src/mvt/common/command.py +++ b/src/mvt/common/command.py @@ -8,6 +8,7 @@ import logging import os import sys from datetime import datetime +from heapq import heappop, heappush from typing import Any, Optional from rich.console import Console @@ -18,6 +19,7 @@ from .alerts import AlertLevel, AlertStore from .config import settings from .indicators import Indicators from .module import EncryptedBackupError, MVTModule, run_module, save_timeline +from .module_types import ModuleTimeline from .utils import ( CustomJSONEncoder, convert_datetime_to_iso, @@ -44,7 +46,7 @@ class Command: disable_indicator_check: bool = False, ) -> None: self.name = "" - self.modules: list[Any] = [] + self.modules: list[type[MVTModule]] = [] self.target_path = target_path self.results_path = results_path @@ -63,10 +65,10 @@ class Command: # This list will contain all executed modules. # We can use this to reference e.g. self.executed[0].results. - self.executed: list[Any] = [] + self.executed: list[MVTModule] = [] self.hashes = hashes self.hash_values: list[dict[str, Any]] = [] - self.timeline: list[dict[str, Any]] = [] + self.timeline: ModuleTimeline = [] # Load IOCs self._create_storage() @@ -258,20 +260,84 @@ class Command: console.print("") console.print(panel) + def _ordered_modules(self) -> Optional[list[type[MVTModule]]]: + """Return enabled modules in stable topological order.""" + module_indexes = {module: index for index, module in enumerate(self.modules)} + + if self.module_name: + selected = [ + module for module in self.modules if module.__name__ == self.module_name + ] + else: + selected = [module for module in self.modules if module.enabled] + + required = set(selected) + pending = list(selected) + while pending: + module = pending.pop() + for dependency in module.dependencies: + if dependency not in module_indexes: + self.log.warning( + "Module %s depends on unavailable module %s. " + "No modules will be run.", + module.__name__, + dependency.__name__, + ) + return None + if dependency not in required: + required.add(dependency) + pending.append(dependency) + + dependents: dict[type[MVTModule], list[type[MVTModule]]] = { + module: [] for module in required + } + indegree = {module: 0 for module in required} + for module in required: + for dependency in module.dependencies: + if dependency not in required: + continue + dependents[dependency].append(module) + indegree[module] += 1 + + ready: list[tuple[int, type[MVTModule]]] = [] + for module, count in indegree.items(): + if count == 0: + heappush(ready, (module_indexes[module], module)) + + ordered = [] + while ready: + _, module = heappop(ready) + ordered.append(module) + for dependent in dependents[module]: + indegree[dependent] -= 1 + if indegree[dependent] == 0: + heappush(ready, (module_indexes[dependent], dependent)) + + if len(ordered) != len(required): + cyclic_modules = sorted( + (module.__name__ for module, count in indegree.items() if count > 0) + ) + self.log.warning( + "Circular module dependency detected involving: %s. " + "No modules will be run.", + ", ".join(cyclic_modules), + ) + return None + + return ordered + def run(self) -> None: + ordered_modules = self._ordered_modules() + if ordered_modules is None: + return + try: self.init() except NotImplementedError: pass - for module in self.modules: - if self.module_name and module.__name__ != self.module_name: - continue - - if not module.enabled and not ( - self.module_name and module.__name__ == self.module_name - ): - continue + executed_by_type: dict[type[MVTModule], MVTModule] = {} + for module in ordered_modules: # FIXME: do we need the logger here module_logger = logging.getLogger(module.__module__) @@ -282,6 +348,10 @@ class Command: module_options=self.module_options, log=module_logger, ) + m.dependency_modules = { + dependency: executed_by_type[dependency] + for dependency in module.dependencies + } if self.iocs.total_ioc_count: m.indicators = self.iocs @@ -305,6 +375,7 @@ class Command: return self.executed.append(m) + executed_by_type[module] = m self.timeline.extend(m.timeline) self.alertstore.extend(m.alertstore.alerts) diff --git a/src/mvt/common/module.py b/src/mvt/common/module.py index 1f6f6ce..b243874 100644 --- a/src/mvt/common/module.py +++ b/src/mvt/common/module.py @@ -9,7 +9,7 @@ import logging import os import re from dataclasses import asdict, is_dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Sequence from .alerts import AlertStore from .indicators import Indicators @@ -43,6 +43,7 @@ class MVTModule: enabled: bool = True slug: Optional[str] = None + dependencies: Sequence[type["MVTModule"]] = () def __init__( self, @@ -71,6 +72,7 @@ class MVTModule: self.file_path: Optional[str] = file_path self.target_path: Optional[str] = target_path self.results_path: Optional[str] = results_path + self.serial: Optional[str] = None self.module_options: Optional[Dict[str, Any]] = ( module_options if module_options else {} ) @@ -81,6 +83,13 @@ class MVTModule: self.results: ModuleResults = results if results else [] self.timeline: ModuleTimeline = [] + self.dependency_modules: Dict[type["MVTModule"], "MVTModule"] = {} + + def get_dependency_results( + self, module_class: type["MVTModule"] + ) -> ModuleResults: + """Return the results produced by a prerequisite module.""" + return self.dependency_modules[module_class].results @classmethod def from_json(cls, json_path: str, log: logging.Logger): diff --git a/tests/common/test_command.py b/tests/common/test_command.py index fcd62a0..64413d9 100644 --- a/tests/common/test_command.py +++ b/tests/common/test_command.py @@ -4,11 +4,59 @@ # https://license.mvt.re/1.1/ import json +import logging from mvt.common.command import Command +from mvt.common.module import MVTModule + + +class RecordingModule(MVTModule): + run_order = [] + + def run(self): + self.run_order.append(self.__class__.__name__) + + def check_indicators(self): + pass + + +class FirstModule(RecordingModule): + def run(self): + super().run() + self.results = ["first"] + + +class SecondModule(RecordingModule): + dependencies = (FirstModule,) + + def run(self): + super().run() + self.results = self.get_dependency_results(FirstModule) + ["second"] + + +class ThirdModule(RecordingModule): + dependencies = (SecondModule,) + + +class IndependentModule(RecordingModule): + pass + + +class RecordingCommand(Command): + def init(self): + self.initialized = True + + def module_init(self, module): + pass + + def finish(self): + pass class TestCommand: + def setup_method(self): + RecordingModule.run_order = [] + def test_store_alerts_handles_bytes(self, tmp_path): cmd = Command(results_path=str(tmp_path)) cmd.alertstore.medium( @@ -21,3 +69,66 @@ class TestCommand: alerts = json.loads((tmp_path / "alerts.json").read_text()) assert alerts[0]["event"]["payload"] == "\\xa8\\xa9" + + def test_modules_run_in_stable_topological_order(self): + cmd = RecordingCommand() + cmd.modules = [ThirdModule, IndependentModule, SecondModule, FirstModule] + + cmd.run() + + assert RecordingModule.run_order == [ + "IndependentModule", + "FirstModule", + "SecondModule", + "ThirdModule", + ] + second = next(module for module in cmd.executed if isinstance(module, SecondModule)) + assert second.results == ["first", "second"] + + def test_selected_module_runs_transitive_dependencies(self): + cmd = RecordingCommand(module_name="ThirdModule") + cmd.modules = [ThirdModule, SecondModule, FirstModule, IndependentModule] + + cmd.run() + + assert RecordingModule.run_order == [ + "FirstModule", + "SecondModule", + "ThirdModule", + ] + + def test_circular_dependency_warns_and_stops(self, caplog): + class CircularOne(RecordingModule): + pass + + class CircularTwo(RecordingModule): + dependencies = (CircularOne,) + + CircularOne.dependencies = (CircularTwo,) + + cmd = RecordingCommand() + cmd.modules = [CircularOne, CircularTwo] + + with caplog.at_level(logging.WARNING): + cmd.run() + + assert RecordingModule.run_order == [] + assert not hasattr(cmd, "initialized") + assert "Circular module dependency detected" in caplog.text + + def test_unavailable_dependency_warns_and_stops(self, caplog): + class UnavailableModule(RecordingModule): + pass + + class DependentModule(RecordingModule): + dependencies = (UnavailableModule,) + + cmd = RecordingCommand() + cmd.modules = [DependentModule] + + with caplog.at_level(logging.WARNING): + cmd.run() + + assert RecordingModule.run_order == [] + assert not hasattr(cmd, "initialized") + assert "depends on unavailable module UnavailableModule" in caplog.text