pkl to safetensors
This commit is contained in:
@@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join(
|
||||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||
)
|
||||
|
||||
VERSION_FLAG: str = "v1.1.1"
|
||||
VERSION_FLAG: str = "v1.1.2"
|
||||
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
|
||||
|
||||
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import os
|
||||
from pprint import pformat, pprint
|
||||
|
||||
import dill as pickle
|
||||
from scripts.faceswaplab_utils import face_utils
|
||||
import gradio as gr
|
||||
import modules.scripts as scripts
|
||||
import onnx
|
||||
import pandas as pd
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
|
||||
from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui
|
||||
from insightface.app.common import Face
|
||||
from modules import scripts
|
||||
from PIL import Image
|
||||
from modules.shared import opts
|
||||
@@ -128,10 +126,17 @@ def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> Optional[st
|
||||
|
||||
|
||||
def sanitize_name(name: str) -> str:
|
||||
logger.debug(f"Sanitize name {name}")
|
||||
"""
|
||||
Sanitize the input name by removing special characters and replacing spaces with underscores.
|
||||
|
||||
Parameters:
|
||||
name (str): The input name to be sanitized.
|
||||
|
||||
Returns:
|
||||
str: The sanitized name with special characters removed and spaces replaced by underscores.
|
||||
"""
|
||||
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
|
||||
name = name.replace(" ", "_")
|
||||
logger.debug(f"Sanitized name {name[:255]}")
|
||||
return name[:255]
|
||||
|
||||
|
||||
@@ -185,25 +190,19 @@ def build_face_checkpoint_and_save(
|
||||
),
|
||||
)
|
||||
|
||||
file_path = os.path.join(faces_path, f"{name}.pkl")
|
||||
file_path = os.path.join(faces_path, f"{name}.safetensors")
|
||||
file_number = 1
|
||||
while os.path.exists(file_path):
|
||||
file_path = os.path.join(faces_path, f"{name}_{file_number}.pkl")
|
||||
file_path = os.path.join(
|
||||
faces_path, f"{name}_{file_number}.safetensors"
|
||||
)
|
||||
file_number += 1
|
||||
result_image.save(file_path + ".png")
|
||||
with open(file_path, "wb") as file:
|
||||
pickle.dump(
|
||||
{
|
||||
"embedding": blended_face.embedding,
|
||||
"gender": blended_face.gender,
|
||||
"age": blended_face.age,
|
||||
},
|
||||
file,
|
||||
)
|
||||
|
||||
face_utils.save_face(filename=file_path, face=blended_face)
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
data = Face(pickle.load(file))
|
||||
print(data)
|
||||
data = face_utils.load_face(filename=file_path)
|
||||
print(data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return result_image
|
||||
|
||||
@@ -4,12 +4,12 @@ import base64
|
||||
import io
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, List, Optional, Set, Union
|
||||
import dill as pickle
|
||||
import gradio as gr
|
||||
from insightface.app.common import Face
|
||||
from PIL import Image
|
||||
from scripts.faceswaplab_utils.imgutils import pil_to_cv2
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from scripts.faceswaplab_utils import face_utils
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -94,8 +94,8 @@ class FaceSwapUnitSettings:
|
||||
if self.source_face and self.source_face != "None":
|
||||
with open(self.source_face, "rb") as file:
|
||||
try:
|
||||
logger.info(f"loading pickle {file.name}")
|
||||
face = Face(pickle.load(file))
|
||||
logger.info(f"loading face {file.name}")
|
||||
face = face_utils.load_face(file.name)
|
||||
self._reference_face = face
|
||||
except Exception as e:
|
||||
logger.error("Failed to load checkpoint : %s", e)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import List
|
||||
from scripts.faceswaplab_utils.models_utils import get_face_checkpoints
|
||||
from scripts.faceswaplab_utils.face_utils import get_face_checkpoints
|
||||
import gradio as gr
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import List
|
||||
from insightface.app.common import Face
|
||||
from safetensors.torch import save_file, safe_open
|
||||
import torch
|
||||
|
||||
import modules.scripts as scripts
|
||||
from modules import scripts
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
|
||||
|
||||
def save_face(face: Face, filename: str) -> None:
|
||||
tensors = {
|
||||
"embedding": torch.tensor(face["embedding"]),
|
||||
"gender": torch.tensor(face["gender"]),
|
||||
"age": torch.tensor(face["age"]),
|
||||
}
|
||||
save_file(tensors, filename)
|
||||
|
||||
|
||||
def load_face(filename: str) -> Face:
|
||||
face = {}
|
||||
logger.debug("Try to load face from %s", filename)
|
||||
with safe_open(filename, framework="pt", device="cpu") as f:
|
||||
logger.debug("File contains %s keys", f.keys())
|
||||
for k in f.keys():
|
||||
logger.debug("load key %s", k)
|
||||
face[k] = f.get_tensor(k).numpy()
|
||||
logger.debug("face : %s", face)
|
||||
return Face(face)
|
||||
|
||||
|
||||
def get_face_checkpoints() -> List[str]:
|
||||
"""
|
||||
Retrieve a list of face checkpoint paths.
|
||||
|
||||
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
|
||||
containing the paths of those files.
|
||||
|
||||
Returns:
|
||||
list: A list of face paths, including the string "None" as the first element.
|
||||
"""
|
||||
faces_path = os.path.join(
|
||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
|
||||
)
|
||||
faces = glob.glob(faces_path)
|
||||
return ["None"] + faces
|
||||
@@ -43,20 +43,3 @@ def get_current_model() -> str:
|
||||
"No faceswap model found. Please add it to the faceswaplab directory."
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def get_face_checkpoints() -> List[str]:
|
||||
"""
|
||||
Retrieve a list of face checkpoint paths.
|
||||
|
||||
This function searches for face files with the extension ".pkl" in the specified directory and returns a list
|
||||
containing the paths of those files.
|
||||
|
||||
Returns:
|
||||
list: A list of face paths, including the string "None" as the first element.
|
||||
"""
|
||||
faces_path = os.path.join(
|
||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
|
||||
)
|
||||
faces = glob.glob(faces_path)
|
||||
return ["None"] + faces
|
||||
|
||||
Reference in New Issue
Block a user