update docs
This commit is contained in:
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import urllib.request
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from scripts.faceswaplab_swapping.swapper import is_sha1_matching
|
||||
from scripts.faceswaplab_utils.models_utils import get_models
|
||||
from scripts.faceswaplab_globals import *
|
||||
from packaging import version
|
||||
import pkg_resources
|
||||
|
||||
ALREADY_DONE = False
|
||||
|
||||
|
||||
def check_configuration() -> None:
|
||||
global ALREADY_DONE
|
||||
|
||||
if ALREADY_DONE:
|
||||
return
|
||||
|
||||
logger.info(f"FaceSwapLab {VERSION_FLAG} Config :")
|
||||
|
||||
# This has been moved here due to pb with sdnext in install.py not doing what a1111 is doing.
|
||||
models_dir = MODELS_DIR
|
||||
faces_dir = FACES_DIR
|
||||
|
||||
model_url = "https://huggingface.co/henryruhs/roop/resolve/main/inswapper_128.onnx"
|
||||
model_name = os.path.basename(model_url)
|
||||
model_path = os.path.join(models_dir, model_name)
|
||||
|
||||
def download(url: str, path: str) -> None:
|
||||
request = urllib.request.urlopen(url)
|
||||
total = int(request.headers.get("Content-Length", 0))
|
||||
with tqdm(
|
||||
total=total,
|
||||
desc="Downloading inswapper model",
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as progress:
|
||||
urllib.request.urlretrieve(
|
||||
url,
|
||||
path,
|
||||
reporthook=lambda count, block_size, total_size: progress.update(
|
||||
block_size
|
||||
),
|
||||
)
|
||||
|
||||
os.makedirs(models_dir, exist_ok=True)
|
||||
os.makedirs(faces_dir, exist_ok=True)
|
||||
|
||||
if not is_sha1_matching(model_path, EXPECTED_INSWAPPER_SHA1):
|
||||
logger.error(
|
||||
"Suspicious sha1 for model %s, check the model is valid or has been downloaded adequately. Should be %s",
|
||||
model_path,
|
||||
EXPECTED_INSWAPPER_SHA1,
|
||||
)
|
||||
|
||||
gradio_version = pkg_resources.get_distribution("gradio").version
|
||||
|
||||
if version.parse(gradio_version) < version.parse("3.32.0"):
|
||||
logger.warning(
|
||||
"Errors may occur with gradio versions lower than 3.32.0. Your version : %s",
|
||||
gradio_version,
|
||||
)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
download(model_url, model_path)
|
||||
|
||||
def print_infos() -> None:
|
||||
logger.info("FaceSwapLab config :")
|
||||
logger.info("+ MODEL DIR : %s", models_dir)
|
||||
models = get_models()
|
||||
logger.info("+ MODELS: %s", models)
|
||||
logger.info("+ FACES DIR : %s", faces_dir)
|
||||
logger.info("+ ANALYZER DIR : %s", ANALYZER_DIR)
|
||||
|
||||
print_infos()
|
||||
|
||||
ALREADY_DONE = True
|
||||
@@ -2,6 +2,7 @@ import importlib
|
||||
import traceback
|
||||
|
||||
from scripts import faceswaplab_globals
|
||||
from scripts.configure import check_configuration
|
||||
from scripts.faceswaplab_api import faceswaplab_api
|
||||
from scripts.faceswaplab_postprocessing import upscaling
|
||||
from scripts.faceswaplab_settings import faceswaplab_settings
|
||||
@@ -65,8 +66,8 @@ except:
|
||||
|
||||
class FaceSwapScript(scripts.Script):
|
||||
def __init__(self) -> None:
|
||||
logger.info(f"FaceSwapLab {VERSION_FLAG}")
|
||||
super().__init__()
|
||||
check_configuration()
|
||||
|
||||
@property
|
||||
def units_count(self) -> int:
|
||||
|
||||
@@ -4,6 +4,8 @@ from modules import scripts
|
||||
MODELS_DIR = os.path.abspath(os.path.join("models", "faceswaplab"))
|
||||
ANALYZER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "analysers"))
|
||||
FACE_PARSER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "parser"))
|
||||
FACES_DIR = os.path.abspath(os.path.join(MODELS_DIR, "faces"))
|
||||
|
||||
REFERENCE_PATH = os.path.join(
|
||||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||
)
|
||||
@@ -13,3 +15,4 @@ EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
|
||||
|
||||
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
|
||||
NSFW_SCORE_THRESHOLD: float = 0.7
|
||||
EXPECTED_INSWAPPER_SHA1 = "17a64851eaefd55ea597ee41e5c18409754244c5"
|
||||
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
import traceback
|
||||
from typing import Any, Dict, Generator, List, Set, Tuple, Optional
|
||||
import tempfile
|
||||
from tqdm import tqdm
|
||||
@@ -271,7 +272,9 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
|
||||
logger.info("Load analysis model, will take some time. (> 30s)")
|
||||
# Initialize the analysis model with the specified name and providers
|
||||
|
||||
with tqdm(total=1, desc="Loading analysis model", unit="model") as pbar:
|
||||
with tqdm(
|
||||
total=1, desc="Loading analysis model (first time is slow)", unit="model"
|
||||
) as pbar:
|
||||
with capture_stdout() as captured:
|
||||
model = insightface.app.FaceAnalysis(
|
||||
name="buffalo_l",
|
||||
@@ -291,14 +294,21 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
|
||||
|
||||
def is_sha1_matching(file_path: str, expected_sha1: str) -> bool:
|
||||
sha1_hash = hashlib.sha1(usedforsecurity=False)
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
for byte_block in iter(lambda: file.read(4096), b""):
|
||||
sha1_hash.update(byte_block)
|
||||
if sha1_hash.hexdigest() == expected_sha1:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
for byte_block in iter(lambda: file.read(4096), b""):
|
||||
sha1_hash.update(byte_block)
|
||||
if sha1_hash.hexdigest() == expected_sha1:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to check model hash, check the model is valid or has been downloaded adequately : %e",
|
||||
e,
|
||||
)
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -334,8 +344,6 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
|
||||
logger.error(
|
||||
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise FaceModelException("Loading of swapping model failed")
|
||||
|
||||
@@ -379,6 +387,8 @@ def get_faces(
|
||||
# Sort the detected faces based on their x-coordinate of the bounding box
|
||||
return sorted(face, key=lambda x: x.bbox[0])
|
||||
except Exception as e:
|
||||
logger.error("Failed to get faces %s", e)
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class UpscaledINSwapper(INSwapper):
|
||||
|
||||
if options:
|
||||
logger.info("*" * 80)
|
||||
logger.info(f"Upscaled inswapper")
|
||||
logger.info(f"Inswapper")
|
||||
|
||||
if options.upscaler_name:
|
||||
# Upscale original image
|
||||
|
||||
Reference in New Issue
Block a user