fix similarity, add checksum for swapper, fix minor bugs
This commit is contained in:
@@ -20,6 +20,10 @@ In short:
|
||||
|
||||
More on this here : https://glucauze.github.io/sd-webui-faceswaplab/
|
||||
|
||||
### Known problems (wontfix):
|
||||
|
||||
+ Older versions of gradio don't work well with the extension. See this bug : https://github.com/glucauze/sd-webui-faceswaplab/issues/5
|
||||
|
||||
### Features
|
||||
|
||||
+ **Face Unit Concept**: Similar to controlNet, the program introduces the concept of a face unit. You can configure up to 10 units (3 units are the default setting) in the program settings (sd).
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@ cython
|
||||
ifnude
|
||||
insightface==0.7.3
|
||||
onnx==1.14.0
|
||||
onnxruntime==1.15.0
|
||||
onnxruntime==1.15.1
|
||||
opencv-python==4.7.0.72
|
||||
pandas
|
||||
pydantic==1.10.9
|
||||
|
||||
@@ -27,7 +27,6 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
)
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
import gradio as gr
|
||||
from scripts.faceswaplab_utils.typing import CV2ImgU8, PILImage, Face
|
||||
from scripts.faceswaplab_inpainting.i2i_pp import img2img_diffusion
|
||||
|
||||
@@ -250,6 +249,21 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
|
||||
raise FaceModelException("Loading of analysis model failed")
|
||||
|
||||
|
||||
import hashlib
|
||||
|
||||
|
||||
def is_sha1_matching(file_path: str, expected_sha1: str) -> bool:
|
||||
sha1_hash = hashlib.sha1()
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
for byte_block in iter(lambda: file.read(4096), b""):
|
||||
sha1_hash.update(byte_block)
|
||||
if sha1_hash.hexdigest() == expected_sha1:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
|
||||
"""
|
||||
@@ -262,6 +276,14 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
|
||||
insightface.model_zoo.FaceModel: The face swap model.
|
||||
"""
|
||||
try:
|
||||
expected_sha1 = "17a64851eaefd55ea597ee41e5c18409754244c5"
|
||||
if not is_sha1_matching(model_path, expected_sha1):
|
||||
logger.error(
|
||||
"Suspicious sha1 for model %s, check the model is valid or has been downloaded adequately. Should be %s",
|
||||
model_path,
|
||||
expected_sha1,
|
||||
)
|
||||
|
||||
# Initializes the face swap model using the specified model path.
|
||||
return upscaled_inswapper.UpscaledINSwapper(
|
||||
insightface.model_zoo.get_model(model_path, providers=providers)
|
||||
@@ -270,6 +292,9 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
|
||||
logger.error(
|
||||
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise FaceModelException("Loading of swapping model failed")
|
||||
|
||||
|
||||
@@ -315,11 +340,15 @@ def get_faces(
|
||||
return []
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceFilteringOptions:
|
||||
faces_index: Set[int]
|
||||
source_gender: Optional[int] = None # if none will not use same gender
|
||||
sort_by_face_size: bool = False
|
||||
|
||||
|
||||
def filter_faces(
|
||||
all_faces: List[Face],
|
||||
faces_index: Set[int],
|
||||
source_gender: int = None,
|
||||
sort_by_face_size: bool = False,
|
||||
all_faces: List[Face], filtering_options: FaceFilteringOptions
|
||||
) -> List[Face]:
|
||||
"""
|
||||
Sorts and filters a list of faces based on specified criteria.
|
||||
@@ -337,18 +366,24 @@ def filter_faces(
|
||||
:return: A list of Face objects sorted and filtered according to the specified criteria.
|
||||
"""
|
||||
filtered_faces = copy.copy(all_faces)
|
||||
if sort_by_face_size:
|
||||
if filtering_options.sort_by_face_size:
|
||||
filtered_faces = sorted(
|
||||
all_faces,
|
||||
reverse=True,
|
||||
key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]),
|
||||
)
|
||||
|
||||
if source_gender is not None:
|
||||
if filtering_options.source_gender is not None:
|
||||
filtered_faces = [
|
||||
face for face in filtered_faces if face["gender"] == source_gender
|
||||
face
|
||||
for face in filtered_faces
|
||||
if face["gender"] == filtering_options.source_gender
|
||||
]
|
||||
return [face for i, face in enumerate(filtered_faces) if i in faces_index]
|
||||
return [
|
||||
face
|
||||
for i, face in enumerate(filtered_faces)
|
||||
if i in filtering_options.faces_index
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -391,7 +426,7 @@ def get_or_default(l: List[Any], index: int, default: Any) -> Any:
|
||||
return l[index] if index < len(l) else default
|
||||
|
||||
|
||||
def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[CV2ImgU8]]:
|
||||
def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]:
|
||||
"""
|
||||
Extracts faces from a list of image files.
|
||||
|
||||
@@ -407,7 +442,7 @@ def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[CV2ImgU8]]:
|
||||
|
||||
if len(files) > 0:
|
||||
for file in files:
|
||||
img = Image.open(file.name) # Open the image file
|
||||
img = Image.open(file) # Open the image file
|
||||
face = get_or_default(
|
||||
get_faces(pil_to_cv2(img)), 0, None
|
||||
) # Extract faces from the image
|
||||
@@ -503,41 +538,44 @@ def swap_face(
|
||||
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
||||
return_result.image = result_image
|
||||
|
||||
# FIXME : recompute similarity
|
||||
|
||||
# if compute_similarity:
|
||||
# try:
|
||||
# result_faces = get_faces(
|
||||
# cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR),
|
||||
# sort_by_face_size=sort_by_face_size,
|
||||
# )
|
||||
# if same_gender:
|
||||
# result_faces = [
|
||||
# x for x in result_faces if x["gender"] == gender
|
||||
# ]
|
||||
|
||||
# for i, swapped_face in enumerate(result_faces):
|
||||
# logger.info(f"compare face {i}")
|
||||
# if i in faces_index and i < len(target_faces):
|
||||
# return_result.similarity[i] = cosine_similarity_face(
|
||||
# source_face, swapped_face
|
||||
# )
|
||||
# return_result.ref_similarity[i] = cosine_similarity_face(
|
||||
# reference_face, swapped_face
|
||||
# )
|
||||
|
||||
# logger.info(f"similarity {return_result.similarity}")
|
||||
# logger.info(f"ref similarity {return_result.ref_similarity}")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error("Similarity processing failed %s", e)
|
||||
# raise e
|
||||
except Exception as e:
|
||||
logger.error("Conversion failed %s", e)
|
||||
raise e
|
||||
return return_result
|
||||
|
||||
|
||||
def compute_similarity(
|
||||
reference_face: Face,
|
||||
source_face: Face,
|
||||
swapped_image: PILImage,
|
||||
filtering: FaceFilteringOptions,
|
||||
) -> Tuple[Dict[int, float], Dict[int, float]]:
|
||||
similarity: Dict[int, float] = {}
|
||||
ref_similarity: Dict[int, float] = {}
|
||||
try:
|
||||
swapped_image_cv2: CV2ImgU8 = cv2.cvtColor(
|
||||
np.array(swapped_image), cv2.COLOR_RGB2BGR
|
||||
)
|
||||
new_faces = filter_faces(get_faces(swapped_image_cv2), filtering)
|
||||
if len(new_faces) == 0:
|
||||
logger.error("compute_similarity : No faces to compare with !")
|
||||
return None
|
||||
|
||||
for i, swapped_face in enumerate(new_faces):
|
||||
logger.info(f"compare face {i}")
|
||||
similarity[i] = cosine_similarity_face(source_face, swapped_face)
|
||||
ref_similarity[i] = cosine_similarity_face(reference_face, swapped_face)
|
||||
|
||||
logger.info(f"similarity {similarity}")
|
||||
logger.info(f"ref similarity {ref_similarity}")
|
||||
|
||||
return (similarity, ref_similarity)
|
||||
except Exception as e:
|
||||
logger.error("Similarity processing failed %s", e)
|
||||
raise e
|
||||
return None
|
||||
|
||||
|
||||
def process_image_unit(
|
||||
model: str,
|
||||
unit: FaceSwapUnitSettings,
|
||||
@@ -580,13 +618,14 @@ def process_image_unit(
|
||||
logger.info("Use source face as reference face")
|
||||
reference_face = src_face
|
||||
|
||||
target_faces = filter_faces(
|
||||
faces,
|
||||
face_filtering_options = FaceFilteringOptions(
|
||||
faces_index=unit.faces_index,
|
||||
source_gender=src_face["gender"] if unit.same_gender else None,
|
||||
sort_by_face_size=unit.sort_by_size,
|
||||
)
|
||||
|
||||
target_faces = filter_faces(faces, filtering_options=face_filtering_options)
|
||||
|
||||
# Apply pre-inpainting to image
|
||||
if unit.pre_inpainting.inpainting_denoising_strengh > 0:
|
||||
current_image = img2img_diffusion(
|
||||
@@ -611,6 +650,18 @@ def process_image_unit(
|
||||
|
||||
save_img_debug(result.image, "After swap")
|
||||
|
||||
if unit.compute_similarity:
|
||||
similarities = compute_similarity(
|
||||
reference_face=reference_face,
|
||||
source_face=src_face,
|
||||
swapped_image=result.image,
|
||||
filtering=face_filtering_options,
|
||||
)
|
||||
if similarities:
|
||||
(result.similarity, result.ref_similarity) = similarities
|
||||
else:
|
||||
logger.error("Failed to compute similarity")
|
||||
|
||||
if result.image is None:
|
||||
logger.error("Result image is None")
|
||||
if (
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from pprint import pformat, pprint
|
||||
from pprint import pformat
|
||||
from typing import *
|
||||
from scripts.faceswaplab_utils.typing import *
|
||||
import gradio as gr
|
||||
import modules.scripts as scripts
|
||||
import onnx
|
||||
import pandas as pd
|
||||
from modules import scripts
|
||||
from modules.shared import opts
|
||||
from PIL import Image
|
||||
|
||||
import scripts.faceswaplab_swapping.swapper as swapper
|
||||
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
)
|
||||
from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
|
||||
from scripts.faceswaplab_utils import face_utils, imgutils
|
||||
from scripts.faceswaplab_utils import face_checkpoints_utils, imgutils
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from scripts.faceswaplab_utils.models_utils import get_models
|
||||
from scripts.faceswaplab_utils.ui_utils import dataclasses_from_flat_list
|
||||
@@ -138,24 +133,9 @@ def analyse_faces(image: PILImage, det_threshold: float = 0.5) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def sanitize_name(name: str) -> str:
|
||||
"""
|
||||
Sanitize the input name by removing special characters and replacing spaces with underscores.
|
||||
|
||||
Parameters:
|
||||
name (str): The input name to be sanitized.
|
||||
|
||||
Returns:
|
||||
str: The sanitized name with special characters removed and spaces replaced by underscores.
|
||||
"""
|
||||
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
|
||||
name = name.replace(" ", "_")
|
||||
return name[:255]
|
||||
|
||||
|
||||
def build_face_checkpoint_and_save(
|
||||
batch_files: gr.File, name: str
|
||||
) -> Optional[PILImage]:
|
||||
batch_files: gr.File, name: str, overwrite: bool
|
||||
) -> PILImage:
|
||||
"""
|
||||
Builds a face checkpoint using the provided image files, performs face swapping,
|
||||
and saves the result to a file. If a blended face is successfully obtained and the face swapping
|
||||
@@ -170,79 +150,19 @@ def build_face_checkpoint_and_save(
|
||||
"""
|
||||
|
||||
try:
|
||||
name = sanitize_name(name)
|
||||
batch_files = batch_files or []
|
||||
logger.info("Build %s %s", name, [x.name for x in batch_files])
|
||||
faces = swapper.get_faces_from_img_files(batch_files)
|
||||
blended_face = swapper.blend_faces(faces)
|
||||
preview_path = os.path.join(
|
||||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||
if not batch_files:
|
||||
logger.error("No face found")
|
||||
return None
|
||||
filenames = [x.name for x in batch_files]
|
||||
preview_image = face_checkpoints_utils.build_face_checkpoint_and_save(
|
||||
filenames, name, overwrite=overwrite
|
||||
)
|
||||
|
||||
faces_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
|
||||
|
||||
os.makedirs(faces_path, exist_ok=True)
|
||||
|
||||
target_img: PILImage = None
|
||||
if blended_face:
|
||||
if blended_face["gender"] == 0:
|
||||
target_img = Image.open(os.path.join(preview_path, "woman.png"))
|
||||
else:
|
||||
target_img = Image.open(os.path.join(preview_path, "man.png"))
|
||||
|
||||
if name == "":
|
||||
name = "default_name"
|
||||
pprint(blended_face)
|
||||
target_face = swapper.get_or_default(
|
||||
swapper.get_faces(imgutils.pil_to_cv2(target_img)), 0, None
|
||||
)
|
||||
if target_face is None:
|
||||
logger.error(
|
||||
"Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold."
|
||||
)
|
||||
else:
|
||||
result = swapper.swap_face(
|
||||
reference_face=blended_face,
|
||||
target_faces=[target_face],
|
||||
source_face=blended_face,
|
||||
target_img=target_img,
|
||||
model=get_models()[0],
|
||||
upscaled_swapper=opts.data.get(
|
||||
"faceswaplab_upscaled_swapper", False
|
||||
),
|
||||
)
|
||||
result_image = enhance_image(
|
||||
result.image,
|
||||
PostProcessingOptions(
|
||||
face_restorer_name="CodeFormer", restorer_visibility=1
|
||||
),
|
||||
)
|
||||
|
||||
file_path = os.path.join(faces_path, f"{name}.safetensors")
|
||||
file_number = 1
|
||||
while os.path.exists(file_path):
|
||||
file_path = os.path.join(
|
||||
faces_path, f"{name}_{file_number}.safetensors"
|
||||
)
|
||||
file_number += 1
|
||||
result_image.save(file_path + ".png")
|
||||
|
||||
face_utils.save_face(filename=file_path, face=blended_face)
|
||||
try:
|
||||
data = face_utils.load_face(filename=file_path)
|
||||
logger.debug(data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return result_image
|
||||
|
||||
logger.error("No face found")
|
||||
except Exception as e:
|
||||
logger.error("Failed to build checkpoint %s", e)
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
return target_img
|
||||
return preview_image
|
||||
|
||||
|
||||
def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame:
|
||||
@@ -281,7 +201,7 @@ def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame:
|
||||
|
||||
def batch_process(
|
||||
files: List[gr.File], save_path: str, *components: Tuple[Any, ...]
|
||||
) -> Optional[List[PILImage]]:
|
||||
) -> List[PILImage]:
|
||||
try:
|
||||
units_count = opts.data.get("faceswaplab_units_count", 3)
|
||||
|
||||
@@ -308,7 +228,7 @@ def batch_process(
|
||||
logger.error("Batch Process error : %s", e)
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
return []
|
||||
|
||||
|
||||
def tools_ui() -> None:
|
||||
@@ -319,7 +239,7 @@ def tools_ui() -> None:
|
||||
"""Build a face based on a batch list of images. Will blend the resulting face and store the checkpoint in the faceswaplab/faces directory."""
|
||||
)
|
||||
with gr.Row():
|
||||
batch_files = gr.components.File(
|
||||
build_batch_files = gr.components.File(
|
||||
type="file",
|
||||
file_count="multiple",
|
||||
label="Batch Sources Images",
|
||||
@@ -332,12 +252,18 @@ def tools_ui() -> None:
|
||||
interactive=False,
|
||||
elem_id="faceswaplab_build_preview_face",
|
||||
)
|
||||
name = gr.Textbox(
|
||||
build_name = gr.Textbox(
|
||||
value="Face",
|
||||
placeholder="Name of the character",
|
||||
label="Name of the character",
|
||||
elem_id="faceswaplab_build_character_name",
|
||||
)
|
||||
build_overwrite = gr.Checkbox(
|
||||
False,
|
||||
placeholder="overwrite",
|
||||
label="Overwrite Checkpoint if exist (else will add number)",
|
||||
elem_id="faceswaplab_build_overwrite",
|
||||
)
|
||||
generate_checkpoint_btn = gr.Button(
|
||||
"Save", elem_id="faceswaplab_build_save_btn"
|
||||
)
|
||||
@@ -452,7 +378,9 @@ def tools_ui() -> None:
|
||||
)
|
||||
compare_btn.click(compare, inputs=[img1, img2], outputs=[compare_result_text])
|
||||
generate_checkpoint_btn.click(
|
||||
build_face_checkpoint_and_save, inputs=[batch_files, name], outputs=[preview]
|
||||
build_face_checkpoint_and_save,
|
||||
inputs=[build_batch_files, build_name, build_overwrite],
|
||||
outputs=[preview],
|
||||
)
|
||||
extract_btn.click(
|
||||
extract_faces,
|
||||
|
||||
@@ -8,7 +8,7 @@ from insightface.app.common import Face
|
||||
from PIL import Image
|
||||
from scripts.faceswaplab_utils.imgutils import pil_to_cv2
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from scripts.faceswaplab_utils import face_utils
|
||||
from scripts.faceswaplab_utils import face_checkpoints_utils
|
||||
from scripts.faceswaplab_inpainting.faceswaplab_inpainting import InpaintingOptions
|
||||
from client_api import api_utils
|
||||
|
||||
@@ -118,14 +118,13 @@ class FaceSwapUnitSettings:
|
||||
"""
|
||||
if not hasattr(self, "_reference_face"):
|
||||
if self.source_face and self.source_face != "None":
|
||||
with open(self.source_face, "rb") as file:
|
||||
try:
|
||||
logger.info(f"loading face {file.name}")
|
||||
face = face_utils.load_face(file.name)
|
||||
self._reference_face = face
|
||||
except Exception as e:
|
||||
logger.error("Failed to load checkpoint : %s", e)
|
||||
raise e
|
||||
try:
|
||||
logger.info(f"loading face {self.source_face}")
|
||||
face = face_checkpoints_utils.load_face(self.source_face)
|
||||
self._reference_face = face
|
||||
except Exception as e:
|
||||
logger.error("Failed to load checkpoint : %s", e)
|
||||
raise e
|
||||
elif self.source_img is not None:
|
||||
if isinstance(self.source_img, str): # source_img is a base64 string
|
||||
if (
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
from scripts.faceswaplab_ui.faceswaplab_inpainting_ui import face_inpainting_ui
|
||||
from scripts.faceswaplab_utils.face_utils import get_face_checkpoints
|
||||
from scripts.faceswaplab_utils.face_checkpoints_utils import get_face_checkpoints
|
||||
import gradio as gr
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import *
|
||||
from insightface.app.common import Face
|
||||
from safetensors.torch import save_file, safe_open
|
||||
import torch
|
||||
|
||||
import modules.scripts as scripts
|
||||
from modules import scripts
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from scripts.faceswaplab_utils.typing import *
|
||||
from scripts.faceswaplab_utils import imgutils
|
||||
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
)
|
||||
from scripts.faceswaplab_utils.models_utils import get_models
|
||||
from modules.shared import opts
|
||||
import traceback
|
||||
|
||||
import dill as pickle # will be removed in future versions
|
||||
from scripts.faceswaplab_swapping import swapper
|
||||
from pprint import pformat
|
||||
import re
|
||||
|
||||
|
||||
def sanitize_name(name: str) -> str:
|
||||
"""
|
||||
Sanitize the input name by removing special characters and replacing spaces with underscores.
|
||||
|
||||
Parameters:
|
||||
name (str): The input name to be sanitized.
|
||||
|
||||
Returns:
|
||||
str: The sanitized name with special characters removed and spaces replaced by underscores.
|
||||
"""
|
||||
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
|
||||
name = name.replace(" ", "_")
|
||||
return name[:255]
|
||||
|
||||
|
||||
def build_face_checkpoint_and_save(
|
||||
batch_files: List[str], name: str, overwrite: bool = False
|
||||
) -> PILImage:
|
||||
"""
|
||||
Builds a face checkpoint using the provided image files, performs face swapping,
|
||||
and saves the result to a file. If a blended face is successfully obtained and the face swapping
|
||||
process succeeds, the resulting image is returned. Otherwise, None is returned.
|
||||
|
||||
Args:
|
||||
batch_files (list): List of image file paths used to create the face checkpoint.
|
||||
name (str): The name assigned to the face checkpoint.
|
||||
|
||||
Returns:
|
||||
PIL.PILImage or None: The resulting swapped face image if the process is successful; None otherwise.
|
||||
"""
|
||||
|
||||
try:
|
||||
name = sanitize_name(name)
|
||||
batch_files = batch_files or []
|
||||
logger.info("Build %s %s", name, [x for x in batch_files])
|
||||
faces = swapper.get_faces_from_img_files(batch_files)
|
||||
blended_face = swapper.blend_faces(faces)
|
||||
preview_path = os.path.join(
|
||||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||
)
|
||||
|
||||
reference_preview_img: PILImage = None
|
||||
if blended_face:
|
||||
if blended_face["gender"] == 0:
|
||||
reference_preview_img = Image.open(
|
||||
os.path.join(preview_path, "woman.png")
|
||||
)
|
||||
else:
|
||||
reference_preview_img = Image.open(
|
||||
os.path.join(preview_path, "man.png")
|
||||
)
|
||||
|
||||
if name == "":
|
||||
name = "default_name"
|
||||
logger.debug("Face %s", pformat(blended_face))
|
||||
target_face = swapper.get_or_default(
|
||||
swapper.get_faces(imgutils.pil_to_cv2(reference_preview_img)), 0, None
|
||||
)
|
||||
if target_face is None:
|
||||
logger.error(
|
||||
"Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold."
|
||||
)
|
||||
else:
|
||||
result = swapper.swap_face(
|
||||
reference_face=blended_face,
|
||||
target_faces=[target_face],
|
||||
source_face=blended_face,
|
||||
target_img=reference_preview_img,
|
||||
model=get_models()[0],
|
||||
upscaled_swapper=opts.data.get(
|
||||
"faceswaplab_upscaled_swapper", False
|
||||
),
|
||||
)
|
||||
preview_image = enhance_image(
|
||||
result.image,
|
||||
PostProcessingOptions(
|
||||
face_restorer_name="CodeFormer", restorer_visibility=1
|
||||
),
|
||||
)
|
||||
|
||||
file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors")
|
||||
if not overwrite:
|
||||
file_number = 1
|
||||
while os.path.exists(file_path):
|
||||
file_path = os.path.join(
|
||||
get_checkpoint_path(), f"{name}_{file_number}.safetensors"
|
||||
)
|
||||
file_number += 1
|
||||
save_face(filename=file_path, face=blended_face)
|
||||
preview_image.save(file_path + ".png")
|
||||
try:
|
||||
data = load_face(file_path)
|
||||
logger.debug(data)
|
||||
except Exception as e:
|
||||
logger.error("Error loading checkpoint, after creation %s", e)
|
||||
traceback.print_exc()
|
||||
|
||||
return preview_image
|
||||
|
||||
else:
|
||||
logger.error("No face found")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Failed to build checkpoint %s", e)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def save_face(face: Face, filename: str) -> None:
|
||||
try:
|
||||
tensors = {
|
||||
"embedding": torch.tensor(face["embedding"]),
|
||||
"gender": torch.tensor(face["gender"]),
|
||||
"age": torch.tensor(face["age"]),
|
||||
}
|
||||
save_file(tensors, filename)
|
||||
except Exception as e:
|
||||
traceback.print_exc
|
||||
logger.error("Failed to save checkpoint %s", e)
|
||||
raise e
|
||||
|
||||
|
||||
def load_face(name: str) -> Face:
|
||||
filename = matching_checkpoint(name)
|
||||
if filename is None:
|
||||
return None
|
||||
|
||||
if filename.endswith(".pkl"):
|
||||
logger.warning(
|
||||
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
|
||||
)
|
||||
logger.warning("The file will be converted to .safetensors")
|
||||
logger.warning(
|
||||
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
|
||||
)
|
||||
with open(filename, "rb") as file:
|
||||
logger.info("Load pkl")
|
||||
face = Face(pickle.load(file))
|
||||
logger.warning(
|
||||
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
|
||||
)
|
||||
save_face(face, filename.replace(".pkl", ".safetensors"))
|
||||
return face
|
||||
|
||||
elif filename.endswith(".safetensors"):
|
||||
face = {}
|
||||
with safe_open(filename, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
logger.debug("load key %s", k)
|
||||
face[k] = f.get_tensor(k).numpy()
|
||||
return Face(face)
|
||||
|
||||
raise NotImplementedError("Unknown file type, face extraction not implemented")
|
||||
|
||||
|
||||
def get_checkpoint_path() -> str:
|
||||
checkpoint_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
|
||||
os.makedirs(checkpoint_path, exist_ok=True)
|
||||
return checkpoint_path
|
||||
|
||||
|
||||
def matching_checkpoint(name: str) -> Optional[str]:
|
||||
"""
|
||||
Retrieve the full path of a checkpoint file matching the given name.
|
||||
|
||||
If the name already includes a path separator, it is returned as-is. Otherwise, the function looks for a matching
|
||||
file with the extensions ".safetensors" or ".pkl" in the checkpoint directory.
|
||||
|
||||
Args:
|
||||
name (str): The name or path of the checkpoint file.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The full path of the matching checkpoint file, or None if no match is found.
|
||||
"""
|
||||
|
||||
# If the name already includes a path separator, return it as is
|
||||
if os.path.sep in name:
|
||||
return name
|
||||
|
||||
# If the name doesn't end with the specified extensions, look for a matching file
|
||||
if not (name.endswith(".safetensors") or name.endswith(".pkl")):
|
||||
# Try appending each extension and check if the file exists in the checkpoint path
|
||||
for ext in [".safetensors", ".pkl"]:
|
||||
full_path = os.path.join(get_checkpoint_path(), name + ext)
|
||||
if os.path.exists(full_path):
|
||||
return full_path
|
||||
# If no matching file is found, return None
|
||||
return None
|
||||
|
||||
# If the name already ends with the specified extensions, simply complete the path
|
||||
return os.path.join(get_checkpoint_path(), name)
|
||||
|
||||
|
||||
def get_face_checkpoints() -> List[str]:
|
||||
"""
|
||||
Retrieve a list of face checkpoint paths.
|
||||
|
||||
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
|
||||
containing the paths of those files.
|
||||
|
||||
Returns:
|
||||
list: A list of face paths, including the string "None" as the first element.
|
||||
"""
|
||||
faces_path = os.path.join(get_checkpoint_path(), "*.safetensors")
|
||||
faces = glob.glob(faces_path)
|
||||
|
||||
faces_path = os.path.join(get_checkpoint_path(), "*.pkl")
|
||||
faces += glob.glob(faces_path)
|
||||
|
||||
return ["None"] + [os.path.basename(face) for face in sorted(faces)]
|
||||
@@ -1,72 +0,0 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import List
|
||||
from insightface.app.common import Face
|
||||
from safetensors.torch import save_file, safe_open
|
||||
import torch
|
||||
|
||||
import modules.scripts as scripts
|
||||
from modules import scripts
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
import dill as pickle # will be removed in future versions
|
||||
|
||||
|
||||
def save_face(face: Face, filename: str) -> None:
|
||||
tensors = {
|
||||
"embedding": torch.tensor(face["embedding"]),
|
||||
"gender": torch.tensor(face["gender"]),
|
||||
"age": torch.tensor(face["age"]),
|
||||
}
|
||||
save_file(tensors, filename)
|
||||
|
||||
|
||||
def load_face(filename: str) -> Face:
|
||||
if filename.endswith(".pkl"):
|
||||
logger.warning(
|
||||
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
|
||||
)
|
||||
logger.warning("The file will be converted to .safetensors")
|
||||
logger.warning(
|
||||
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
|
||||
)
|
||||
with open(filename, "rb") as file:
|
||||
logger.info("Load pkl")
|
||||
face = Face(pickle.load(file))
|
||||
logger.warning(
|
||||
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
|
||||
)
|
||||
save_face(face, filename.replace(".pkl", ".safetensors"))
|
||||
return face
|
||||
|
||||
elif filename.endswith(".safetensors"):
|
||||
face = {}
|
||||
with safe_open(filename, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
logger.debug("load key %s", k)
|
||||
face[k] = f.get_tensor(k).numpy()
|
||||
return Face(face)
|
||||
|
||||
raise NotImplementedError("Unknown file type, face extraction not implemented")
|
||||
|
||||
|
||||
def get_face_checkpoints() -> List[str]:
|
||||
"""
|
||||
Retrieve a list of face checkpoint paths.
|
||||
|
||||
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
|
||||
containing the paths of those files.
|
||||
|
||||
Returns:
|
||||
list: A list of face paths, including the string "None" as the first element.
|
||||
"""
|
||||
faces_path = os.path.join(
|
||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
|
||||
)
|
||||
faces = glob.glob(faces_path)
|
||||
|
||||
faces_path = os.path.join(
|
||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
|
||||
)
|
||||
faces += glob.glob(faces_path)
|
||||
|
||||
return ["None"] + sorted(faces)
|
||||
@@ -11,6 +11,7 @@ from modules import processing
|
||||
import base64
|
||||
from collections import Counter
|
||||
from scripts.faceswaplab_utils.typing import BoxCoords, CV2ImgU8, PILImage
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
|
||||
|
||||
def check_against_nsfw(img: PILImage) -> bool:
|
||||
@@ -157,19 +158,6 @@ def create_square_image(image_list: List[PILImage]) -> Optional[PILImage]:
|
||||
return None
|
||||
|
||||
|
||||
# def create_mask(image : PILImage, box_coords : Tuple[int, int, int, int]) -> PILImage:
|
||||
# width, height = image.size
|
||||
# mask = Image.new("L", (width, height), 255)
|
||||
# x1, y1, x2, y2 = box_coords
|
||||
# for x in range(width):
|
||||
# for y in range(height):
|
||||
# if x1 <= x <= x2 and y1 <= y <= y2:
|
||||
# mask.putpixel((x, y), 255)
|
||||
# else:
|
||||
# mask.putpixel((x, y), 0)
|
||||
# return mask
|
||||
|
||||
|
||||
def create_mask(
|
||||
image: PILImage,
|
||||
box_coords: BoxCoords,
|
||||
@@ -216,7 +204,9 @@ def apply_mask(
|
||||
if overlays is None or batch_index >= len(overlays):
|
||||
return img
|
||||
overlay: PILImage = overlays[batch_index]
|
||||
overlay = overlay.resize((img.size), resample=Image.Resampling.LANCZOS)
|
||||
logger.debug("Overlay size %s, Image size %s", overlay.size, img.size)
|
||||
if overlay.size != img.size:
|
||||
overlay = overlay.resize((img.size), resample=Image.Resampling.LANCZOS)
|
||||
img = img.copy()
|
||||
img.paste(overlay, (0, 0), overlay)
|
||||
return img
|
||||
|
||||
Reference in New Issue
Block a user