pkl to safetensors

This commit is contained in:
Tran Xen
2023-07-30 00:55:17 +02:00
parent be02fdcd7d
commit 31635d369f
11 changed files with 94 additions and 44 deletions
+1 -1
View File
@@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
)
VERSION_FLAG: str = "v1.1.1"
VERSION_FLAG: str = "v1.1.2"
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
+18 -19
View File
@@ -1,14 +1,12 @@
import os
from pprint import pformat, pprint
import dill as pickle
from scripts.faceswaplab_utils import face_utils
import gradio as gr
import modules.scripts as scripts
import onnx
import pandas as pd
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui
from insightface.app.common import Face
from modules import scripts
from PIL import Image
from modules.shared import opts
@@ -128,10 +126,17 @@ def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> Optional[st
def sanitize_name(name: str) -> str:
logger.debug(f"Sanitize name {name}")
"""
Sanitize the input name by removing special characters and replacing spaces with underscores.
Parameters:
name (str): The input name to be sanitized.
Returns:
str: The sanitized name with special characters removed and spaces replaced by underscores.
"""
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
name = name.replace(" ", "_")
logger.debug(f"Sanitized name {name[:255]}")
return name[:255]
@@ -185,25 +190,19 @@ def build_face_checkpoint_and_save(
),
)
file_path = os.path.join(faces_path, f"{name}.pkl")
file_path = os.path.join(faces_path, f"{name}.safetensors")
file_number = 1
while os.path.exists(file_path):
file_path = os.path.join(faces_path, f"{name}_{file_number}.pkl")
file_path = os.path.join(
faces_path, f"{name}_{file_number}.safetensors"
)
file_number += 1
result_image.save(file_path + ".png")
with open(file_path, "wb") as file:
pickle.dump(
{
"embedding": blended_face.embedding,
"gender": blended_face.gender,
"age": blended_face.age,
},
file,
)
face_utils.save_face(filename=file_path, face=blended_face)
try:
with open(file_path, "rb") as file:
data = Face(pickle.load(file))
print(data)
data = face_utils.load_face(filename=file_path)
print(data)
except Exception as e:
print(e)
return result_image
@@ -4,12 +4,12 @@ import base64
import io
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Set, Union
import dill as pickle
import gradio as gr
from insightface.app.common import Face
from PIL import Image
from scripts.faceswaplab_utils.imgutils import pil_to_cv2
from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_utils import face_utils
@dataclass
@@ -94,8 +94,8 @@ class FaceSwapUnitSettings:
if self.source_face and self.source_face != "None":
with open(self.source_face, "rb") as file:
try:
logger.info(f"loading pickle {file.name}")
face = Face(pickle.load(file))
logger.info(f"loading face {file.name}")
face = face_utils.load_face(file.name)
self._reference_face = face
except Exception as e:
logger.error("Failed to load checkpoint : %s", e)
@@ -1,5 +1,5 @@
from typing import List
from scripts.faceswaplab_utils.models_utils import get_face_checkpoints
from scripts.faceswaplab_utils.face_utils import get_face_checkpoints
import gradio as gr
+48
View File
@@ -0,0 +1,48 @@
import glob
import os
from typing import List
from insightface.app.common import Face
from safetensors.torch import save_file, safe_open
import torch
import modules.scripts as scripts
from modules import scripts
from scripts.faceswaplab_utils.faceswaplab_logging import logger
def save_face(face: Face, filename: str) -> None:
tensors = {
"embedding": torch.tensor(face["embedding"]),
"gender": torch.tensor(face["gender"]),
"age": torch.tensor(face["age"]),
}
save_file(tensors, filename)
def load_face(filename: str) -> Face:
face = {}
logger.debug("Try to load face from %s", filename)
with safe_open(filename, framework="pt", device="cpu") as f:
logger.debug("File contains %s keys", f.keys())
for k in f.keys():
logger.debug("load key %s", k)
face[k] = f.get_tensor(k).numpy()
logger.debug("face : %s", face)
return Face(face)
def get_face_checkpoints() -> List[str]:
"""
Retrieve a list of face checkpoint paths.
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
containing the paths of those files.
Returns:
list: A list of face paths, including the string "None" as the first element.
"""
faces_path = os.path.join(
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
)
faces = glob.glob(faces_path)
return ["None"] + faces
-17
View File
@@ -43,20 +43,3 @@ def get_current_model() -> str:
"No faceswap model found. Please add it to the faceswaplab directory."
)
return model
def get_face_checkpoints() -> List[str]:
"""
Retrieve a list of face checkpoint paths.
This function searches for face files with the extension ".pkl" in the specified directory and returns a list
containing the paths of those files.
Returns:
list: A list of face paths, including the string "None" as the first element.
"""
faces_path = os.path.join(
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
)
faces = glob.glob(faces_path)
return ["None"] + faces