Improves typing

This commit is contained in:
tek
2023-03-24 19:02:02 +01:00
parent 1ad176788b
commit 49491800fb
7 changed files with 34 additions and 21 deletions

View File

@@ -31,6 +31,7 @@ class ChromeHistory(AndroidExtraction):
super().__init__(file_path=file_path, target_path=target_path,
results_path=results_path, fast_mode=fast_mode,
log=log, results=results)
self.results = []
def serialize(self, record: dict) -> Union[dict, list]:
return {
@@ -55,6 +56,7 @@ class ChromeHistory(AndroidExtraction):
:param db_path: Path to the History database to process.
"""
assert isinstance(self.results, list) # assert results type for mypy
conn = sqlite3.connect(db_path)
cur = conn.cursor()
cur.execute("""

View File

@@ -39,7 +39,7 @@ class Files(AndroidExtraction):
log=log, results=results)
self.full_find = False
def serialize(self, record: dict) -> Union[dict, list]:
def serialize(self, record: dict) -> Union[dict, list, None]:
if "modified_time" in record:
return {
"timestamp": record["modified_time"],
@@ -62,6 +62,9 @@ class Files(AndroidExtraction):
self.detected.append(result)
def backup_file(self, file_path: str) -> None:
if not self.results_path:
return
local_file_name = file_path.replace("/", "_").replace(" ", "-")
local_files_folder = os.path.join(self.results_path, "files")
if not os.path.exists(local_files_folder):
@@ -79,6 +82,7 @@ class Files(AndroidExtraction):
file_path, local_file_path)
def find_files(self, folder: str) -> None:
assert isinstance(self.results, list)
if self.full_find:
cmd = f"find '{folder}' -type f -printf '%T@ %m %s %u %g %p\n' 2> /dev/null"
output = self._adb_command(cmd)
@@ -121,8 +125,7 @@ class Files(AndroidExtraction):
for entry in self.results:
self.log.info("Found file in tmp folder at path %s",
entry.get("path"))
if self.results_path:
self.backup_file(entry.get("path"))
self.backup_file(entry.get("path"))
for media_folder in ANDROID_MEDIA_FOLDERS:
self.find_files(media_folder)

View File

@@ -30,6 +30,7 @@ class CmdCheckIOCS(Command):
self.name = "check-iocs"
def run(self) -> None:
assert self.target_path is not None
all_modules = []
for entry in self.modules:
if entry not in all_modules:

View File

@@ -6,7 +6,7 @@
import json
import logging
import os
from typing import Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union
from appdirs import user_data_dir
@@ -23,7 +23,7 @@ class Indicators:
def __init__(self, log=logging.Logger) -> None:
self.log = log
self.ioc_collections = []
self.ioc_collections: List[Dict[str, Any]] = []
self.total_ioc_count = 0
def _load_downloaded_indicators(self) -> None:
@@ -209,7 +209,7 @@ class Indicators:
self.log.info("Loaded a total of %d unique indicators",
self.total_ioc_count)
def get_iocs(self, ioc_type: str) -> Union[dict, None]:
def get_iocs(self, ioc_type: str) -> Union[Iterator[Dict[str, Any]], None]:
for ioc_collection in self.ioc_collections:
for ioc in ioc_collection.get(ioc_type, []):
yield {

View File

@@ -7,7 +7,7 @@ import csv
import logging
import os
import re
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import simplejson as json
@@ -28,7 +28,7 @@ class MVTModule:
"""This class provides a base for all extraction modules."""
enabled = True
slug = None
slug: Optional[str] = None
def __init__(
self,
@@ -61,12 +61,12 @@ class MVTModule:
self.log = log
self.indicators = None
self.results = results if results else []
self.detected = []
self.timeline = []
self.timeline_detected = []
self.detected: List[Dict[str, Any]] = []
self.timeline: List[Dict[str, str]] = []
self.timeline_detected: List[Dict[str, str]] = []
@classmethod
def from_json(cls, json_path: str, log: logging.Logger = None):
def from_json(cls, json_path: str, log: logging.Logger):
with open(json_path, "r", encoding="utf-8") as handle:
results = json.load(handle)
if log:
@@ -116,7 +116,7 @@ class MVTModule:
with open(detected_json_path, "w", encoding="utf-8") as handle:
json.dump(self.detected, handle, indent=4, default=str)
def serialize(self, record: dict) -> Union[dict, list]:
def serialize(self, record: dict) -> Union[dict, list, None]:
raise NotImplementedError
@staticmethod
@@ -159,7 +159,7 @@ class MVTModule:
raise NotImplementedError
def run_module(module: Callable) -> None:
def run_module(module: MVTModule) -> None:
module.log.info("Running module %s...", module.__class__.__name__)
try:

View File

@@ -6,6 +6,7 @@
import logging
import os
from datetime import datetime
from typing import Optional, Tuple
import requests
import yaml
@@ -83,7 +84,7 @@ class IndicatorsUpdates:
with open(self.latest_update_path, "w", encoding="utf-8") as handle:
handle.write(str(timestamp))
def get_remote_index(self) -> dict:
def get_remote_index(self) -> Optional[dict]:
url = self.github_raw_url.format(self.index_owner, self.index_repo,
self.index_branch, self.index_path)
res = requests.get(url)
@@ -94,7 +95,7 @@ class IndicatorsUpdates:
return yaml.safe_load(res.content)
def download_remote_ioc(self, ioc_url: str) -> str:
def download_remote_ioc(self, ioc_url: str) -> Optional[str]:
res = requests.get(ioc_url)
if res.status_code != 200:
log.error("Failed to download indicators file from %s (error %d)",
@@ -116,6 +117,9 @@ class IndicatorsUpdates:
os.makedirs(MVT_INDICATORS_FOLDER)
index = self.get_remote_index()
if not index:
return
for ioc in index.get("indicators", []):
ioc_type = ioc.get("type", "")
@@ -171,7 +175,7 @@ class IndicatorsUpdates:
return latest_commit_ts
def should_check(self) -> (bool, int):
def should_check(self) -> Tuple[bool, int]:
now = datetime.utcnow()
latest_check_ts = self.get_latest_check()
latest_check_dt = datetime.fromtimestamp(latest_check_ts)
@@ -182,7 +186,7 @@ class IndicatorsUpdates:
if diff_hours >= INDICATORS_CHECK_FREQUENCY:
return True, 0
return False, INDICATORS_CHECK_FREQUENCY - diff_hours
return False, int(INDICATORS_CHECK_FREQUENCY - diff_hours)
def check(self) -> bool:
self.set_latest_check()
@@ -197,6 +201,9 @@ class IndicatorsUpdates:
return True
index = self.get_remote_index()
if not index:
return False
for ioc in index.get("indicators", []):
if ioc.get("type", "") != "github":
continue

View File

@@ -7,7 +7,7 @@ import datetime
import hashlib
import os
import re
from typing import Iterator, Union
from typing import Any, Iterator, Union
def convert_chrometime_to_datetime(timestamp: int) -> datetime.datetime:
@@ -51,7 +51,7 @@ def convert_unix_to_utc_datetime(
return datetime.datetime.utcfromtimestamp(float(timestamp))
def convert_unix_to_iso(timestamp: int) -> str:
def convert_unix_to_iso(timestamp: Union[int, float, str]) -> str:
"""Converts a unix epoch to ISO string.
:param timestamp: Epoc timestamp to convert.
@@ -125,7 +125,7 @@ def check_for_links(text: str) -> list:
# Note: taken from here:
# https://stackoverflow.com/questions/57014259/json-dumps-on-dictionary-with-bytes-for-keys
def keys_bytes_to_string(obj) -> str:
def keys_bytes_to_string(obj: Any) -> Any:
"""Convert object keys from bytes to string.
:param obj: Object to convert from bytes to string.