experimental gpu option

This commit is contained in:
Tran Xen
2023-08-04 14:23:50 +02:00
parent 7282abf9a9
commit 9bb1c65db6
6 changed files with 61 additions and 18 deletions
@@ -37,11 +37,22 @@ def on_ui_settings() -> None:
),
)
shared.opts.add_option(
"faceswaplab_det_size",
shared.OptionInfo(
640,
"det_size : Size of the detection area for face analysis. Higher values may improve quality but reduce speed. Low value may improve detection of very large face.",
gr.Slider,
{"minimum": 320, "maximum": 640, "step": 320},
section=section,
),
)
shared.opts.add_option(
"faceswaplab_detection_threshold",
shared.OptionInfo(
0.5,
"Face Detection threshold",
"det_thresh : Face Detection threshold",
gr.Slider,
{"minimum": 0.1, "maximum": 0.99, "step": 0.001},
section=section,
+23 -15
View File
@@ -37,8 +37,19 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import (
from scripts.faceswaplab_utils.models_utils import get_current_model
from scripts.faceswaplab_utils.typing import CV2ImgU8, PILImage, Face
from scripts.faceswaplab_inpainting.i2i_pp import img2img_diffusion
from modules import shared
providers = ["CPUExecutionProvider"]
use_gpu = getattr(shared.cmd_opts, "faceswaplab_gpu", False)
if use_gpu and sys.platform != "darwin":
providers = [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
else:
providers = ["CPUExecutionProvider"]
def cosine_similarity_face(face1: Face, face2: Face) -> float:
@@ -258,7 +269,9 @@ def capture_stdout() -> Generator[StringIO, None, None]:
@lru_cache(maxsize=1)
def getAnalysisModel() -> insightface.app.FaceAnalysis:
def getAnalysisModel(
det_size: Tuple[int, int] = (640, 640), det_thresh: float = 0.5
) -> insightface.app.FaceAnalysis:
"""
Retrieves the analysis model for face analysis.
@@ -269,7 +282,9 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
if not os.path.exists(faceswaplab_globals.ANALYZER_DIR):
os.makedirs(faceswaplab_globals.ANALYZER_DIR)
logger.info("Load analysis model, will take some time. (> 30s)")
logger.info(
f"Load analysis model det_size={det_size}, det_thresh={det_thresh}, will take some time. (> 30s)"
)
# Initialize the analysis model with the specified name and providers
with tqdm(
@@ -281,6 +296,9 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
providers=providers,
root=faceswaplab_globals.ANALYZER_DIR,
)
# Prepare the analysis model for face detection with the specified detection size
model.prepare(ctx_id=0, det_thresh=det_thresh, det_size=det_size)
pbar.update(1)
logger.info("%s", pformat(captured.getvalue()))
@@ -350,7 +368,6 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
def get_faces(
img_data: CV2ImgU8,
det_size: Tuple[int, int] = (640, 640),
det_thresh: Optional[float] = None,
) -> List[Face]:
"""
@@ -368,21 +385,12 @@ def get_faces(
if det_thresh is None:
det_thresh = opts.data.get("faceswaplab_detection_threshold", 0.5)
# Create a deep copy of the analysis model (otherwise det_size is attached to the analysis model and can't be changed)
face_analyser = copy.deepcopy(getAnalysisModel())
# Prepare the analysis model for face detection with the specified detection size
face_analyser.prepare(ctx_id=0, det_thresh=det_thresh, det_size=det_size)
det_size = opts.data.get("faceswaplab_det_size", 640)
face_analyser = getAnalysisModel((det_size, det_size), det_thresh)
# Get the detected faces from the image using the analysis model
face = face_analyser.get(img_data)
# If no faces are detected and the detection size is larger than 320x320,
# recursively call the function with a smaller detection size
if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
det_size_half = (det_size[0] // 2, det_size[1] // 2)
return get_faces(img_data, det_size=det_size_half, det_thresh=det_thresh)
try:
# Sort the detected faces based on their x-coordinate of the bounding box
return sorted(face, key=lambda x: x.bbox[0])
@@ -262,7 +262,6 @@ class UpscaledINSwapper(INSwapper):
)
img_white[img_white > 20] = 255
fthresh = 10
print("fthresh", fthresh)
fake_diff[fake_diff < fthresh] = 0
fake_diff[fake_diff >= fthresh] = 255
img_mask = img_white