update docs

This commit is contained in:
Tran Xen
2023-08-04 00:19:04 +02:00
parent 02d88bac91
commit be1cd15432
9 changed files with 179 additions and 94 deletions
+79
View File
@@ -0,0 +1,79 @@
import os
from tqdm import tqdm
import urllib.request
from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_swapping.swapper import is_sha1_matching
from scripts.faceswaplab_utils.models_utils import get_models
from scripts.faceswaplab_globals import *
from packaging import version
import pkg_resources
ALREADY_DONE = False
def check_configuration() -> None:
global ALREADY_DONE
if ALREADY_DONE:
return
logger.info(f"FaceSwapLab {VERSION_FLAG} Config :")
# This has been moved here due to pb with sdnext in install.py not doing what a1111 is doing.
models_dir = MODELS_DIR
faces_dir = FACES_DIR
model_url = "https://huggingface.co/henryruhs/roop/resolve/main/inswapper_128.onnx"
model_name = os.path.basename(model_url)
model_path = os.path.join(models_dir, model_name)
def download(url: str, path: str) -> None:
request = urllib.request.urlopen(url)
total = int(request.headers.get("Content-Length", 0))
with tqdm(
total=total,
desc="Downloading inswapper model",
unit="B",
unit_scale=True,
unit_divisor=1024,
) as progress:
urllib.request.urlretrieve(
url,
path,
reporthook=lambda count, block_size, total_size: progress.update(
block_size
),
)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(faces_dir, exist_ok=True)
if not is_sha1_matching(model_path, EXPECTED_INSWAPPER_SHA1):
logger.error(
"Suspicious sha1 for model %s, check the model is valid or has been downloaded adequately. Should be %s",
model_path,
EXPECTED_INSWAPPER_SHA1,
)
gradio_version = pkg_resources.get_distribution("gradio").version
if version.parse(gradio_version) < version.parse("3.32.0"):
logger.warning(
"Errors may occur with gradio versions lower than 3.32.0. Your version : %s",
gradio_version,
)
if not os.path.exists(model_path):
download(model_url, model_path)
def print_infos() -> None:
logger.info("FaceSwapLab config :")
logger.info("+ MODEL DIR : %s", models_dir)
models = get_models()
logger.info("+ MODELS: %s", models)
logger.info("+ FACES DIR : %s", faces_dir)
logger.info("+ ANALYZER DIR : %s", ANALYZER_DIR)
print_infos()
ALREADY_DONE = True
+2 -1
View File
@@ -2,6 +2,7 @@ import importlib
import traceback
from scripts import faceswaplab_globals
from scripts.configure import check_configuration
from scripts.faceswaplab_api import faceswaplab_api
from scripts.faceswaplab_postprocessing import upscaling
from scripts.faceswaplab_settings import faceswaplab_settings
@@ -65,8 +66,8 @@ except:
class FaceSwapScript(scripts.Script):
def __init__(self) -> None:
logger.info(f"FaceSwapLab {VERSION_FLAG}")
super().__init__()
check_configuration()
@property
def units_count(self) -> int:
+3
View File
@@ -4,6 +4,8 @@ from modules import scripts
MODELS_DIR = os.path.abspath(os.path.join("models", "faceswaplab"))
ANALYZER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "analysers"))
FACE_PARSER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "parser"))
FACES_DIR = os.path.abspath(os.path.join(MODELS_DIR, "faces"))
REFERENCE_PATH = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
)
@@ -13,3 +15,4 @@ EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
NSFW_SCORE_THRESHOLD: float = 0.7
EXPECTED_INSWAPPER_SHA1 = "17a64851eaefd55ea597ee41e5c18409754244c5"
+21 -11
View File
@@ -2,6 +2,7 @@ import copy
import os
from dataclasses import dataclass
from pprint import pformat
import traceback
from typing import Any, Dict, Generator, List, Set, Tuple, Optional
import tempfile
from tqdm import tqdm
@@ -271,7 +272,9 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
logger.info("Load analysis model, will take some time. (> 30s)")
# Initialize the analysis model with the specified name and providers
with tqdm(total=1, desc="Loading analysis model", unit="model") as pbar:
with tqdm(
total=1, desc="Loading analysis model (first time is slow)", unit="model"
) as pbar:
with capture_stdout() as captured:
model = insightface.app.FaceAnalysis(
name="buffalo_l",
@@ -291,14 +294,21 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
def is_sha1_matching(file_path: str, expected_sha1: str) -> bool:
sha1_hash = hashlib.sha1(usedforsecurity=False)
with open(file_path, "rb") as file:
for byte_block in iter(lambda: file.read(4096), b""):
sha1_hash.update(byte_block)
if sha1_hash.hexdigest() == expected_sha1:
return True
else:
return False
try:
with open(file_path, "rb") as file:
for byte_block in iter(lambda: file.read(4096), b""):
sha1_hash.update(byte_block)
if sha1_hash.hexdigest() == expected_sha1:
return True
else:
return False
except Exception as e:
logger.error(
"Failed to check model hash, check the model is valid or has been downloaded adequately : %e",
e,
)
traceback.print_exc()
return False
@lru_cache(maxsize=1)
@@ -334,8 +344,6 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
logger.error(
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
)
import traceback
traceback.print_exc()
raise FaceModelException("Loading of swapping model failed")
@@ -379,6 +387,8 @@ def get_faces(
# Sort the detected faces based on their x-coordinate of the bounding box
return sorted(face, key=lambda x: x.bbox[0])
except Exception as e:
logger.error("Failed to get faces %s", e)
traceback.print_exc()
return []
@@ -193,7 +193,7 @@ class UpscaledINSwapper(INSwapper):
if options:
logger.info("*" * 80)
logger.info(f"Upscaled inswapper")
logger.info(f"Inswapper")
if options.upscaler_name:
# Upscale original image