auto convert pkl
This commit is contained in:
@@ -8,6 +8,7 @@ import torch
|
||||
import modules.scripts as scripts
|
||||
from modules import scripts
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
import dill as pickle # will be removed in future versions
|
||||
|
||||
|
||||
def save_face(face: Face, filename: str) -> None:
|
||||
@@ -20,15 +21,32 @@ def save_face(face: Face, filename: str) -> None:
|
||||
|
||||
|
||||
def load_face(filename: str) -> Face:
|
||||
face = {}
|
||||
logger.debug("Try to load face from %s", filename)
|
||||
with safe_open(filename, framework="pt", device="cpu") as f:
|
||||
logger.debug("File contains %s keys", f.keys())
|
||||
for k in f.keys():
|
||||
logger.debug("load key %s", k)
|
||||
face[k] = f.get_tensor(k).numpy()
|
||||
logger.debug("face : %s", face)
|
||||
return Face(face)
|
||||
if filename.endswith(".pkl"):
|
||||
logger.warning(
|
||||
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
|
||||
)
|
||||
logger.warning("The file will be converted to .safetensors")
|
||||
logger.warning(
|
||||
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
|
||||
)
|
||||
with open(filename, "rb") as file:
|
||||
logger.info("Load pkl")
|
||||
face = Face(pickle.load(file))
|
||||
logger.warning(
|
||||
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
|
||||
)
|
||||
save_face(face, filename.replace(".pkl", ".safetensors"))
|
||||
return face
|
||||
|
||||
elif filename.endswith(".safetensors"):
|
||||
face = {}
|
||||
with safe_open(filename, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
logger.debug("load key %s", k)
|
||||
face[k] = f.get_tensor(k).numpy()
|
||||
return Face(face)
|
||||
|
||||
raise NotImplementedError("Unknown file type, face extraction not implemented")
|
||||
|
||||
|
||||
def get_face_checkpoints() -> List[str]:
|
||||
@@ -45,4 +63,10 @@ def get_face_checkpoints() -> List[str]:
|
||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
|
||||
)
|
||||
faces = glob.glob(faces_path)
|
||||
return ["None"] + faces
|
||||
|
||||
faces_path = os.path.join(
|
||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
|
||||
)
|
||||
faces += glob.glob(faces_path)
|
||||
|
||||
return ["None"] + sorted(faces)
|
||||
|
||||
Reference in New Issue
Block a user