auto convert pkl

This commit is contained in:
Tran Xen
2023-07-30 13:49:24 +02:00
parent 31635d369f
commit 7538c724ef
6 changed files with 49 additions and 31 deletions
+34 -10
View File
@@ -8,6 +8,7 @@ import torch
import modules.scripts as scripts
from modules import scripts
from scripts.faceswaplab_utils.faceswaplab_logging import logger
import dill as pickle # will be removed in future versions
def save_face(face: Face, filename: str) -> None:
@@ -20,15 +21,32 @@ def save_face(face: Face, filename: str) -> None:
def load_face(filename: str) -> Face:
face = {}
logger.debug("Try to load face from %s", filename)
with safe_open(filename, framework="pt", device="cpu") as f:
logger.debug("File contains %s keys", f.keys())
for k in f.keys():
logger.debug("load key %s", k)
face[k] = f.get_tensor(k).numpy()
logger.debug("face : %s", face)
return Face(face)
if filename.endswith(".pkl"):
logger.warning(
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
)
logger.warning("The file will be converted to .safetensors")
logger.warning(
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
)
with open(filename, "rb") as file:
logger.info("Load pkl")
face = Face(pickle.load(file))
logger.warning(
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
)
save_face(face, filename.replace(".pkl", ".safetensors"))
return face
elif filename.endswith(".safetensors"):
face = {}
with safe_open(filename, framework="pt", device="cpu") as f:
for k in f.keys():
logger.debug("load key %s", k)
face[k] = f.get_tensor(k).numpy()
return Face(face)
raise NotImplementedError("Unknown file type, face extraction not implemented")
def get_face_checkpoints() -> List[str]:
@@ -45,4 +63,10 @@ def get_face_checkpoints() -> List[str]:
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
)
faces = glob.glob(faces_path)
return ["None"] + faces
faces_path = os.path.join(
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
)
faces += glob.glob(faces_path)
return ["None"] + sorted(faces)