mirror of
https://github.com/mvt-project/mvt.git
synced 2026-02-13 00:52:44 +00:00
Improves typing
This commit is contained in:
@@ -31,6 +31,7 @@ class ChromeHistory(AndroidExtraction):
|
||||
super().__init__(file_path=file_path, target_path=target_path,
|
||||
results_path=results_path, fast_mode=fast_mode,
|
||||
log=log, results=results)
|
||||
self.results = []
|
||||
|
||||
def serialize(self, record: dict) -> Union[dict, list]:
|
||||
return {
|
||||
@@ -55,6 +56,7 @@ class ChromeHistory(AndroidExtraction):
|
||||
:param db_path: Path to the History database to process.
|
||||
|
||||
"""
|
||||
assert isinstance(self.results, list) # assert results type for mypy
|
||||
conn = sqlite3.connect(db_path)
|
||||
cur = conn.cursor()
|
||||
cur.execute("""
|
||||
|
||||
@@ -39,7 +39,7 @@ class Files(AndroidExtraction):
|
||||
log=log, results=results)
|
||||
self.full_find = False
|
||||
|
||||
def serialize(self, record: dict) -> Union[dict, list]:
|
||||
def serialize(self, record: dict) -> Union[dict, list, None]:
|
||||
if "modified_time" in record:
|
||||
return {
|
||||
"timestamp": record["modified_time"],
|
||||
@@ -62,6 +62,9 @@ class Files(AndroidExtraction):
|
||||
self.detected.append(result)
|
||||
|
||||
def backup_file(self, file_path: str) -> None:
|
||||
if not self.results_path:
|
||||
return
|
||||
|
||||
local_file_name = file_path.replace("/", "_").replace(" ", "-")
|
||||
local_files_folder = os.path.join(self.results_path, "files")
|
||||
if not os.path.exists(local_files_folder):
|
||||
@@ -79,6 +82,7 @@ class Files(AndroidExtraction):
|
||||
file_path, local_file_path)
|
||||
|
||||
def find_files(self, folder: str) -> None:
|
||||
assert isinstance(self.results, list)
|
||||
if self.full_find:
|
||||
cmd = f"find '{folder}' -type f -printf '%T@ %m %s %u %g %p\n' 2> /dev/null"
|
||||
output = self._adb_command(cmd)
|
||||
@@ -121,8 +125,7 @@ class Files(AndroidExtraction):
|
||||
for entry in self.results:
|
||||
self.log.info("Found file in tmp folder at path %s",
|
||||
entry.get("path"))
|
||||
if self.results_path:
|
||||
self.backup_file(entry.get("path"))
|
||||
self.backup_file(entry.get("path"))
|
||||
|
||||
for media_folder in ANDROID_MEDIA_FOLDERS:
|
||||
self.find_files(media_folder)
|
||||
|
||||
@@ -30,6 +30,7 @@ class CmdCheckIOCS(Command):
|
||||
self.name = "check-iocs"
|
||||
|
||||
def run(self) -> None:
|
||||
assert self.target_path is not None
|
||||
all_modules = []
|
||||
for entry in self.modules:
|
||||
if entry not in all_modules:
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from appdirs import user_data_dir
|
||||
|
||||
@@ -23,7 +23,7 @@ class Indicators:
|
||||
|
||||
def __init__(self, log=logging.Logger) -> None:
|
||||
self.log = log
|
||||
self.ioc_collections = []
|
||||
self.ioc_collections: List[Dict[str, Any]] = []
|
||||
self.total_ioc_count = 0
|
||||
|
||||
def _load_downloaded_indicators(self) -> None:
|
||||
@@ -209,7 +209,7 @@ class Indicators:
|
||||
self.log.info("Loaded a total of %d unique indicators",
|
||||
self.total_ioc_count)
|
||||
|
||||
def get_iocs(self, ioc_type: str) -> Union[dict, None]:
|
||||
def get_iocs(self, ioc_type: str) -> Union[Iterator[Dict[str, Any]], None]:
|
||||
for ioc_collection in self.ioc_collections:
|
||||
for ioc in ioc_collection.get(ioc_type, []):
|
||||
yield {
|
||||
|
||||
@@ -7,7 +7,7 @@ import csv
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import simplejson as json
|
||||
|
||||
@@ -28,7 +28,7 @@ class MVTModule:
|
||||
"""This class provides a base for all extraction modules."""
|
||||
|
||||
enabled = True
|
||||
slug = None
|
||||
slug: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -61,12 +61,12 @@ class MVTModule:
|
||||
self.log = log
|
||||
self.indicators = None
|
||||
self.results = results if results else []
|
||||
self.detected = []
|
||||
self.timeline = []
|
||||
self.timeline_detected = []
|
||||
self.detected: List[Dict[str, Any]] = []
|
||||
self.timeline: List[Dict[str, str]] = []
|
||||
self.timeline_detected: List[Dict[str, str]] = []
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_path: str, log: logging.Logger = None):
|
||||
def from_json(cls, json_path: str, log: logging.Logger):
|
||||
with open(json_path, "r", encoding="utf-8") as handle:
|
||||
results = json.load(handle)
|
||||
if log:
|
||||
@@ -116,7 +116,7 @@ class MVTModule:
|
||||
with open(detected_json_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(self.detected, handle, indent=4, default=str)
|
||||
|
||||
def serialize(self, record: dict) -> Union[dict, list]:
|
||||
def serialize(self, record: dict) -> Union[dict, list, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@@ -159,7 +159,7 @@ class MVTModule:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def run_module(module: Callable) -> None:
|
||||
def run_module(module: MVTModule) -> None:
|
||||
module.log.info("Running module %s...", module.__class__.__name__)
|
||||
|
||||
try:
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -83,7 +84,7 @@ class IndicatorsUpdates:
|
||||
with open(self.latest_update_path, "w", encoding="utf-8") as handle:
|
||||
handle.write(str(timestamp))
|
||||
|
||||
def get_remote_index(self) -> dict:
|
||||
def get_remote_index(self) -> Optional[dict]:
|
||||
url = self.github_raw_url.format(self.index_owner, self.index_repo,
|
||||
self.index_branch, self.index_path)
|
||||
res = requests.get(url)
|
||||
@@ -94,7 +95,7 @@ class IndicatorsUpdates:
|
||||
|
||||
return yaml.safe_load(res.content)
|
||||
|
||||
def download_remote_ioc(self, ioc_url: str) -> str:
|
||||
def download_remote_ioc(self, ioc_url: str) -> Optional[str]:
|
||||
res = requests.get(ioc_url)
|
||||
if res.status_code != 200:
|
||||
log.error("Failed to download indicators file from %s (error %d)",
|
||||
@@ -116,6 +117,9 @@ class IndicatorsUpdates:
|
||||
os.makedirs(MVT_INDICATORS_FOLDER)
|
||||
|
||||
index = self.get_remote_index()
|
||||
if not index:
|
||||
return
|
||||
|
||||
for ioc in index.get("indicators", []):
|
||||
ioc_type = ioc.get("type", "")
|
||||
|
||||
@@ -171,7 +175,7 @@ class IndicatorsUpdates:
|
||||
|
||||
return latest_commit_ts
|
||||
|
||||
def should_check(self) -> (bool, int):
|
||||
def should_check(self) -> Tuple[bool, int]:
|
||||
now = datetime.utcnow()
|
||||
latest_check_ts = self.get_latest_check()
|
||||
latest_check_dt = datetime.fromtimestamp(latest_check_ts)
|
||||
@@ -182,7 +186,7 @@ class IndicatorsUpdates:
|
||||
if diff_hours >= INDICATORS_CHECK_FREQUENCY:
|
||||
return True, 0
|
||||
|
||||
return False, INDICATORS_CHECK_FREQUENCY - diff_hours
|
||||
return False, int(INDICATORS_CHECK_FREQUENCY - diff_hours)
|
||||
|
||||
def check(self) -> bool:
|
||||
self.set_latest_check()
|
||||
@@ -197,6 +201,9 @@ class IndicatorsUpdates:
|
||||
return True
|
||||
|
||||
index = self.get_remote_index()
|
||||
if not index:
|
||||
return False
|
||||
|
||||
for ioc in index.get("indicators", []):
|
||||
if ioc.get("type", "") != "github":
|
||||
continue
|
||||
|
||||
@@ -7,7 +7,7 @@ import datetime
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
from typing import Iterator, Union
|
||||
from typing import Any, Iterator, Union
|
||||
|
||||
|
||||
def convert_chrometime_to_datetime(timestamp: int) -> datetime.datetime:
|
||||
@@ -51,7 +51,7 @@ def convert_unix_to_utc_datetime(
|
||||
return datetime.datetime.utcfromtimestamp(float(timestamp))
|
||||
|
||||
|
||||
def convert_unix_to_iso(timestamp: int) -> str:
|
||||
def convert_unix_to_iso(timestamp: Union[int, float, str]) -> str:
|
||||
"""Converts a unix epoch to ISO string.
|
||||
|
||||
:param timestamp: Epoc timestamp to convert.
|
||||
@@ -125,7 +125,7 @@ def check_for_links(text: str) -> list:
|
||||
|
||||
# Note: taken from here:
|
||||
# https://stackoverflow.com/questions/57014259/json-dumps-on-dictionary-with-bytes-for-keys
|
||||
def keys_bytes_to_string(obj) -> str:
|
||||
def keys_bytes_to_string(obj: Any) -> Any:
|
||||
"""Convert object keys from bytes to string.
|
||||
|
||||
:param obj: Object to convert from bytes to string.
|
||||
|
||||
Reference in New Issue
Block a user