diff --git a/mvt/android/modules/adb/chrome_history.py b/mvt/android/modules/adb/chrome_history.py index 7d0a437..cdd4e6f 100644 --- a/mvt/android/modules/adb/chrome_history.py +++ b/mvt/android/modules/adb/chrome_history.py @@ -31,6 +31,7 @@ class ChromeHistory(AndroidExtraction): super().__init__(file_path=file_path, target_path=target_path, results_path=results_path, fast_mode=fast_mode, log=log, results=results) + self.results = [] def serialize(self, record: dict) -> Union[dict, list]: return { @@ -55,6 +56,7 @@ class ChromeHistory(AndroidExtraction): :param db_path: Path to the History database to process. """ + assert isinstance(self.results, list) # assert results type for mypy conn = sqlite3.connect(db_path) cur = conn.cursor() cur.execute(""" diff --git a/mvt/android/modules/adb/files.py b/mvt/android/modules/adb/files.py index 5c98c46..522ad77 100644 --- a/mvt/android/modules/adb/files.py +++ b/mvt/android/modules/adb/files.py @@ -39,7 +39,7 @@ class Files(AndroidExtraction): log=log, results=results) self.full_find = False - def serialize(self, record: dict) -> Union[dict, list]: + def serialize(self, record: dict) -> Union[dict, list, None]: if "modified_time" in record: return { "timestamp": record["modified_time"], @@ -62,6 +62,9 @@ class Files(AndroidExtraction): self.detected.append(result) def backup_file(self, file_path: str) -> None: + if not self.results_path: + return + local_file_name = file_path.replace("/", "_").replace(" ", "-") local_files_folder = os.path.join(self.results_path, "files") if not os.path.exists(local_files_folder): @@ -79,6 +82,7 @@ class Files(AndroidExtraction): file_path, local_file_path) def find_files(self, folder: str) -> None: + assert isinstance(self.results, list) if self.full_find: cmd = f"find '{folder}' -type f -printf '%T@ %m %s %u %g %p\n' 2> /dev/null" output = self._adb_command(cmd) @@ -121,8 +125,7 @@ class Files(AndroidExtraction): for entry in self.results: self.log.info("Found file in tmp folder at path %s", entry.get("path")) - if self.results_path: - self.backup_file(entry.get("path")) + self.backup_file(entry.get("path")) for media_folder in ANDROID_MEDIA_FOLDERS: self.find_files(media_folder) diff --git a/mvt/common/cmd_check_iocs.py b/mvt/common/cmd_check_iocs.py index 0277291..a654202 100644 --- a/mvt/common/cmd_check_iocs.py +++ b/mvt/common/cmd_check_iocs.py @@ -30,6 +30,7 @@ class CmdCheckIOCS(Command): self.name = "check-iocs" def run(self) -> None: + assert self.target_path is not None all_modules = [] for entry in self.modules: if entry not in all_modules: diff --git a/mvt/common/indicators.py b/mvt/common/indicators.py index 87dae68..b27f0ca 100644 --- a/mvt/common/indicators.py +++ b/mvt/common/indicators.py @@ -6,7 +6,7 @@ import json import logging import os -from typing import Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union from appdirs import user_data_dir @@ -23,7 +23,7 @@ class Indicators: def __init__(self, log=logging.Logger) -> None: self.log = log - self.ioc_collections = [] + self.ioc_collections: List[Dict[str, Any]] = [] self.total_ioc_count = 0 def _load_downloaded_indicators(self) -> None: @@ -209,7 +209,7 @@ class Indicators: self.log.info("Loaded a total of %d unique indicators", self.total_ioc_count) - def get_iocs(self, ioc_type: str) -> Union[dict, None]: + def get_iocs(self, ioc_type: str) -> Union[Iterator[Dict[str, Any]], None]: for ioc_collection in self.ioc_collections: for ioc in ioc_collection.get(ioc_type, []): yield { diff --git a/mvt/common/module.py b/mvt/common/module.py index cd35535..ee5766d 100644 --- a/mvt/common/module.py +++ b/mvt/common/module.py @@ -7,7 +7,7 @@ import csv import logging import os import re -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import simplejson as json @@ -28,7 +28,7 @@ class MVTModule: """This class provides a base for all extraction modules.""" enabled = True - slug = None + slug: Optional[str] = None def __init__( self, @@ -61,12 +61,12 @@ class MVTModule: self.log = log self.indicators = None self.results = results if results else [] - self.detected = [] - self.timeline = [] - self.timeline_detected = [] + self.detected: List[Dict[str, Any]] = [] + self.timeline: List[Dict[str, str]] = [] + self.timeline_detected: List[Dict[str, str]] = [] @classmethod - def from_json(cls, json_path: str, log: logging.Logger = None): + def from_json(cls, json_path: str, log: logging.Logger): with open(json_path, "r", encoding="utf-8") as handle: results = json.load(handle) if log: @@ -116,7 +116,7 @@ class MVTModule: with open(detected_json_path, "w", encoding="utf-8") as handle: json.dump(self.detected, handle, indent=4, default=str) - def serialize(self, record: dict) -> Union[dict, list]: + def serialize(self, record: dict) -> Union[dict, list, None]: raise NotImplementedError @staticmethod @@ -159,7 +159,7 @@ class MVTModule: raise NotImplementedError -def run_module(module: Callable) -> None: +def run_module(module: MVTModule) -> None: module.log.info("Running module %s...", module.__class__.__name__) try: diff --git a/mvt/common/updates.py b/mvt/common/updates.py index 8617249..787f61d 100644 --- a/mvt/common/updates.py +++ b/mvt/common/updates.py @@ -6,6 +6,7 @@ import logging import os from datetime import datetime +from typing import Optional, Tuple import requests import yaml @@ -83,7 +84,7 @@ class IndicatorsUpdates: with open(self.latest_update_path, "w", encoding="utf-8") as handle: handle.write(str(timestamp)) - def get_remote_index(self) -> dict: + def get_remote_index(self) -> Optional[dict]: url = self.github_raw_url.format(self.index_owner, self.index_repo, self.index_branch, self.index_path) res = requests.get(url) @@ -94,7 +95,7 @@ class IndicatorsUpdates: return yaml.safe_load(res.content) - def download_remote_ioc(self, ioc_url: str) -> str: + def download_remote_ioc(self, ioc_url: str) -> Optional[str]: res = requests.get(ioc_url) if res.status_code != 200: log.error("Failed to download indicators file from %s (error %d)", @@ -116,6 +117,9 @@ class IndicatorsUpdates: os.makedirs(MVT_INDICATORS_FOLDER) index = self.get_remote_index() + if not index: + return + for ioc in index.get("indicators", []): ioc_type = ioc.get("type", "") @@ -171,7 +175,7 @@ class IndicatorsUpdates: return latest_commit_ts - def should_check(self) -> (bool, int): + def should_check(self) -> Tuple[bool, int]: now = datetime.utcnow() latest_check_ts = self.get_latest_check() latest_check_dt = datetime.fromtimestamp(latest_check_ts) @@ -182,7 +186,7 @@ class IndicatorsUpdates: if diff_hours >= INDICATORS_CHECK_FREQUENCY: return True, 0 - return False, INDICATORS_CHECK_FREQUENCY - diff_hours + return False, int(INDICATORS_CHECK_FREQUENCY - diff_hours) def check(self) -> bool: self.set_latest_check() @@ -197,6 +201,9 @@ class IndicatorsUpdates: return True index = self.get_remote_index() + if not index: + return False + for ioc in index.get("indicators", []): if ioc.get("type", "") != "github": continue diff --git a/mvt/common/utils.py b/mvt/common/utils.py index bd4ad6c..76ee21f 100644 --- a/mvt/common/utils.py +++ b/mvt/common/utils.py @@ -7,7 +7,7 @@ import datetime import hashlib import os import re -from typing import Iterator, Union +from typing import Any, Iterator, Union def convert_chrometime_to_datetime(timestamp: int) -> datetime.datetime: @@ -51,7 +51,7 @@ def convert_unix_to_utc_datetime( return datetime.datetime.utcfromtimestamp(float(timestamp)) -def convert_unix_to_iso(timestamp: int) -> str: +def convert_unix_to_iso(timestamp: Union[int, float, str]) -> str: """Converts a unix epoch to ISO string. :param timestamp: Epoc timestamp to convert. @@ -125,7 +125,7 @@ def check_for_links(text: str) -> list: # Note: taken from here: # https://stackoverflow.com/questions/57014259/json-dumps-on-dictionary-with-bytes-for-keys -def keys_bytes_to_string(obj) -> str: +def keys_bytes_to_string(obj: Any) -> Any: """Convert object keys from bytes to string. :param obj: Object to convert from bytes to string.