mirror of
https://github.com/mvt-project/mvt.git
synced 2026-06-11 09:17:49 +02:00
Add dependency-aware module ordering
This commit is contained in:
@@ -21,6 +21,24 @@ make check
|
||||
|
||||
Run these tests before making new commits or opening pull requests.
|
||||
|
||||
## Module dependencies
|
||||
|
||||
Modules can require other modules to run first by declaring their classes in
|
||||
`dependencies`. The command runner uses a stable topological ordering, so the
|
||||
existing module list order is preserved wherever dependency constraints allow.
|
||||
|
||||
```python
|
||||
class DependentModule(MVTModule):
|
||||
dependencies = (PrerequisiteModule,)
|
||||
|
||||
def run(self):
|
||||
prerequisite_results = self.get_dependency_results(PrerequisiteModule)
|
||||
```
|
||||
|
||||
Selecting a single module also runs its transitive dependencies. If a dependency
|
||||
is unavailable or the dependency graph contains a cycle, the command logs a
|
||||
warning and does not run any modules.
|
||||
|
||||
## Profiling
|
||||
|
||||
Some MVT modules extract and process significant amounts of data during the analysis process or while checking results against known indicators. Care must be
|
||||
|
||||
+82
-11
@@ -8,6 +8,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from heapq import heappop, heappush
|
||||
from typing import Any, Optional
|
||||
|
||||
from rich.console import Console
|
||||
@@ -18,6 +19,7 @@ from .alerts import AlertLevel, AlertStore
|
||||
from .config import settings
|
||||
from .indicators import Indicators
|
||||
from .module import EncryptedBackupError, MVTModule, run_module, save_timeline
|
||||
from .module_types import ModuleTimeline
|
||||
from .utils import (
|
||||
CustomJSONEncoder,
|
||||
convert_datetime_to_iso,
|
||||
@@ -44,7 +46,7 @@ class Command:
|
||||
disable_indicator_check: bool = False,
|
||||
) -> None:
|
||||
self.name = ""
|
||||
self.modules: list[Any] = []
|
||||
self.modules: list[type[MVTModule]] = []
|
||||
|
||||
self.target_path = target_path
|
||||
self.results_path = results_path
|
||||
@@ -63,10 +65,10 @@ class Command:
|
||||
|
||||
# This list will contain all executed modules.
|
||||
# We can use this to reference e.g. self.executed[0].results.
|
||||
self.executed: list[Any] = []
|
||||
self.executed: list[MVTModule] = []
|
||||
self.hashes = hashes
|
||||
self.hash_values: list[dict[str, Any]] = []
|
||||
self.timeline: list[dict[str, Any]] = []
|
||||
self.timeline: ModuleTimeline = []
|
||||
|
||||
# Load IOCs
|
||||
self._create_storage()
|
||||
@@ -258,20 +260,84 @@ class Command:
|
||||
console.print("")
|
||||
console.print(panel)
|
||||
|
||||
def _ordered_modules(self) -> Optional[list[type[MVTModule]]]:
|
||||
"""Return enabled modules in stable topological order."""
|
||||
module_indexes = {module: index for index, module in enumerate(self.modules)}
|
||||
|
||||
if self.module_name:
|
||||
selected = [
|
||||
module for module in self.modules if module.__name__ == self.module_name
|
||||
]
|
||||
else:
|
||||
selected = [module for module in self.modules if module.enabled]
|
||||
|
||||
required = set(selected)
|
||||
pending = list(selected)
|
||||
while pending:
|
||||
module = pending.pop()
|
||||
for dependency in module.dependencies:
|
||||
if dependency not in module_indexes:
|
||||
self.log.warning(
|
||||
"Module %s depends on unavailable module %s. "
|
||||
"No modules will be run.",
|
||||
module.__name__,
|
||||
dependency.__name__,
|
||||
)
|
||||
return None
|
||||
if dependency not in required:
|
||||
required.add(dependency)
|
||||
pending.append(dependency)
|
||||
|
||||
dependents: dict[type[MVTModule], list[type[MVTModule]]] = {
|
||||
module: [] for module in required
|
||||
}
|
||||
indegree = {module: 0 for module in required}
|
||||
for module in required:
|
||||
for dependency in module.dependencies:
|
||||
if dependency not in required:
|
||||
continue
|
||||
dependents[dependency].append(module)
|
||||
indegree[module] += 1
|
||||
|
||||
ready: list[tuple[int, type[MVTModule]]] = []
|
||||
for module, count in indegree.items():
|
||||
if count == 0:
|
||||
heappush(ready, (module_indexes[module], module))
|
||||
|
||||
ordered = []
|
||||
while ready:
|
||||
_, module = heappop(ready)
|
||||
ordered.append(module)
|
||||
for dependent in dependents[module]:
|
||||
indegree[dependent] -= 1
|
||||
if indegree[dependent] == 0:
|
||||
heappush(ready, (module_indexes[dependent], dependent))
|
||||
|
||||
if len(ordered) != len(required):
|
||||
cyclic_modules = sorted(
|
||||
(module.__name__ for module, count in indegree.items() if count > 0)
|
||||
)
|
||||
self.log.warning(
|
||||
"Circular module dependency detected involving: %s. "
|
||||
"No modules will be run.",
|
||||
", ".join(cyclic_modules),
|
||||
)
|
||||
return None
|
||||
|
||||
return ordered
|
||||
|
||||
def run(self) -> None:
|
||||
ordered_modules = self._ordered_modules()
|
||||
if ordered_modules is None:
|
||||
return
|
||||
|
||||
try:
|
||||
self.init()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
for module in self.modules:
|
||||
if self.module_name and module.__name__ != self.module_name:
|
||||
continue
|
||||
|
||||
if not module.enabled and not (
|
||||
self.module_name and module.__name__ == self.module_name
|
||||
):
|
||||
continue
|
||||
executed_by_type: dict[type[MVTModule], MVTModule] = {}
|
||||
for module in ordered_modules:
|
||||
|
||||
# FIXME: do we need the logger here
|
||||
module_logger = logging.getLogger(module.__module__)
|
||||
@@ -282,6 +348,10 @@ class Command:
|
||||
module_options=self.module_options,
|
||||
log=module_logger,
|
||||
)
|
||||
m.dependency_modules = {
|
||||
dependency: executed_by_type[dependency]
|
||||
for dependency in module.dependencies
|
||||
}
|
||||
|
||||
if self.iocs.total_ioc_count:
|
||||
m.indicators = self.iocs
|
||||
@@ -305,6 +375,7 @@ class Command:
|
||||
return
|
||||
|
||||
self.executed.append(m)
|
||||
executed_by_type[module] = m
|
||||
self.timeline.extend(m.timeline)
|
||||
self.alertstore.extend(m.alertstore.alerts)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
from .alerts import AlertStore
|
||||
from .indicators import Indicators
|
||||
@@ -43,6 +43,7 @@ class MVTModule:
|
||||
|
||||
enabled: bool = True
|
||||
slug: Optional[str] = None
|
||||
dependencies: Sequence[type["MVTModule"]] = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -71,6 +72,7 @@ class MVTModule:
|
||||
self.file_path: Optional[str] = file_path
|
||||
self.target_path: Optional[str] = target_path
|
||||
self.results_path: Optional[str] = results_path
|
||||
self.serial: Optional[str] = None
|
||||
self.module_options: Optional[Dict[str, Any]] = (
|
||||
module_options if module_options else {}
|
||||
)
|
||||
@@ -81,6 +83,13 @@ class MVTModule:
|
||||
|
||||
self.results: ModuleResults = results if results else []
|
||||
self.timeline: ModuleTimeline = []
|
||||
self.dependency_modules: Dict[type["MVTModule"], "MVTModule"] = {}
|
||||
|
||||
def get_dependency_results(
|
||||
self, module_class: type["MVTModule"]
|
||||
) -> ModuleResults:
|
||||
"""Return the results produced by a prerequisite module."""
|
||||
return self.dependency_modules[module_class].results
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_path: str, log: logging.Logger):
|
||||
|
||||
@@ -4,11 +4,59 @@
|
||||
# https://license.mvt.re/1.1/
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from mvt.common.command import Command
|
||||
from mvt.common.module import MVTModule
|
||||
|
||||
|
||||
class RecordingModule(MVTModule):
|
||||
run_order = []
|
||||
|
||||
def run(self):
|
||||
self.run_order.append(self.__class__.__name__)
|
||||
|
||||
def check_indicators(self):
|
||||
pass
|
||||
|
||||
|
||||
class FirstModule(RecordingModule):
|
||||
def run(self):
|
||||
super().run()
|
||||
self.results = ["first"]
|
||||
|
||||
|
||||
class SecondModule(RecordingModule):
|
||||
dependencies = (FirstModule,)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
self.results = self.get_dependency_results(FirstModule) + ["second"]
|
||||
|
||||
|
||||
class ThirdModule(RecordingModule):
|
||||
dependencies = (SecondModule,)
|
||||
|
||||
|
||||
class IndependentModule(RecordingModule):
|
||||
pass
|
||||
|
||||
|
||||
class RecordingCommand(Command):
|
||||
def init(self):
|
||||
self.initialized = True
|
||||
|
||||
def module_init(self, module):
|
||||
pass
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestCommand:
|
||||
def setup_method(self):
|
||||
RecordingModule.run_order = []
|
||||
|
||||
def test_store_alerts_handles_bytes(self, tmp_path):
|
||||
cmd = Command(results_path=str(tmp_path))
|
||||
cmd.alertstore.medium(
|
||||
@@ -21,3 +69,66 @@ class TestCommand:
|
||||
|
||||
alerts = json.loads((tmp_path / "alerts.json").read_text())
|
||||
assert alerts[0]["event"]["payload"] == "\\xa8\\xa9"
|
||||
|
||||
def test_modules_run_in_stable_topological_order(self):
|
||||
cmd = RecordingCommand()
|
||||
cmd.modules = [ThirdModule, IndependentModule, SecondModule, FirstModule]
|
||||
|
||||
cmd.run()
|
||||
|
||||
assert RecordingModule.run_order == [
|
||||
"IndependentModule",
|
||||
"FirstModule",
|
||||
"SecondModule",
|
||||
"ThirdModule",
|
||||
]
|
||||
second = next(module for module in cmd.executed if isinstance(module, SecondModule))
|
||||
assert second.results == ["first", "second"]
|
||||
|
||||
def test_selected_module_runs_transitive_dependencies(self):
|
||||
cmd = RecordingCommand(module_name="ThirdModule")
|
||||
cmd.modules = [ThirdModule, SecondModule, FirstModule, IndependentModule]
|
||||
|
||||
cmd.run()
|
||||
|
||||
assert RecordingModule.run_order == [
|
||||
"FirstModule",
|
||||
"SecondModule",
|
||||
"ThirdModule",
|
||||
]
|
||||
|
||||
def test_circular_dependency_warns_and_stops(self, caplog):
|
||||
class CircularOne(RecordingModule):
|
||||
pass
|
||||
|
||||
class CircularTwo(RecordingModule):
|
||||
dependencies = (CircularOne,)
|
||||
|
||||
CircularOne.dependencies = (CircularTwo,)
|
||||
|
||||
cmd = RecordingCommand()
|
||||
cmd.modules = [CircularOne, CircularTwo]
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
cmd.run()
|
||||
|
||||
assert RecordingModule.run_order == []
|
||||
assert not hasattr(cmd, "initialized")
|
||||
assert "Circular module dependency detected" in caplog.text
|
||||
|
||||
def test_unavailable_dependency_warns_and_stops(self, caplog):
|
||||
class UnavailableModule(RecordingModule):
|
||||
pass
|
||||
|
||||
class DependentModule(RecordingModule):
|
||||
dependencies = (UnavailableModule,)
|
||||
|
||||
cmd = RecordingCommand()
|
||||
cmd.modules = [DependentModule]
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
cmd.run()
|
||||
|
||||
assert RecordingModule.run_order == []
|
||||
assert not hasattr(cmd, "initialized")
|
||||
assert "depends on unavailable module UnavailableModule" in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user