add tests
This commit is contained in:
@@ -13,9 +13,6 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting
|
||||
from scripts.faceswaplab_utils.imgutils import (
|
||||
base64_to_pil,
|
||||
)
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
from modules.shared import opts
|
||||
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
)
|
||||
@@ -135,22 +132,18 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
|
||||
units: List[FaceSwapUnitSettings] = []
|
||||
src_image: Optional[Image.Image] = base64_to_pil(request.image)
|
||||
response = FaceSwapResponse(images=[], infos=[])
|
||||
if request.postprocessing:
|
||||
pp_options = get_postprocessing_options(request.postprocessing)
|
||||
|
||||
if src_image is not None:
|
||||
if request.postprocessing:
|
||||
pp_options = get_postprocessing_options(request.postprocessing)
|
||||
units = get_faceswap_units_settings(request.units)
|
||||
|
||||
swapped_images = swapper.process_images_units(
|
||||
get_current_model(),
|
||||
images=[(src_image, None)],
|
||||
units=units,
|
||||
upscaled_swapper=opts.data.get("faceswaplab_upscaled_swapper", False),
|
||||
swapped_images = swapper.batch_process(
|
||||
[src_image], None, units=units, postprocess_options=pp_options
|
||||
)
|
||||
for img, info in swapped_images:
|
||||
if pp_options:
|
||||
img = enhance_image(img, pp_options)
|
||||
response.images.append(encode_to_base64(img))
|
||||
response.infos.append(info)
|
||||
|
||||
for img in swapped_images:
|
||||
response.images.append(encode_to_base64(img))
|
||||
|
||||
response.infos = [] # Not used atm
|
||||
return response
|
||||
|
||||
@@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join(
|
||||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||
)
|
||||
|
||||
VERSION_FLAG: str = "v1.1.0"
|
||||
VERSION_FLAG: str = "v1.1.1"
|
||||
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
|
||||
|
||||
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
|
||||
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Set, Tuple, Optional
|
||||
import tempfile
|
||||
|
||||
import cv2
|
||||
import insightface
|
||||
@@ -21,6 +22,12 @@ from scripts import faceswaplab_globals
|
||||
from modules.shared import opts
|
||||
from functools import lru_cache
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
|
||||
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
)
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
|
||||
|
||||
providers = ["CPUExecutionProvider"]
|
||||
|
||||
@@ -78,6 +85,53 @@ def compare_faces(img1: Image.Image, img2: Image.Image) -> float:
|
||||
return -1
|
||||
|
||||
|
||||
def batch_process(
|
||||
src_images: List[Image.Image],
|
||||
save_path: Optional[str],
|
||||
units: List[FaceSwapUnitSettings],
|
||||
postprocess_options: PostProcessingOptions,
|
||||
) -> Optional[List[Image.Image]]:
|
||||
try:
|
||||
if save_path:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
units = [u for u in units if u.enable]
|
||||
if src_images is not None and len(units) > 0:
|
||||
result_images = []
|
||||
for src_image in src_images:
|
||||
current_images = []
|
||||
swapped_images = process_images_units(
|
||||
get_current_model(),
|
||||
images=[(src_image, None)],
|
||||
units=units,
|
||||
upscaled_swapper=opts.data.get(
|
||||
"faceswaplab_upscaled_swapper", False
|
||||
),
|
||||
)
|
||||
if len(swapped_images) > 0:
|
||||
current_images += [img for img, _ in swapped_images]
|
||||
|
||||
logger.info("%s images generated", len(current_images))
|
||||
for i, img in enumerate(current_images):
|
||||
current_images[i] = enhance_image(img, postprocess_options)
|
||||
|
||||
if save_path:
|
||||
for img in current_images:
|
||||
path = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=save_path
|
||||
).name
|
||||
img.save(path)
|
||||
|
||||
result_images += current_images
|
||||
return result_images
|
||||
except Exception as e:
|
||||
logger.error("Batch Process error : %s", e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
class FaceModelException(Exception):
|
||||
"""Exception raised when an error is encountered in the face model."""
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from dataclasses import fields
|
||||
from typing import Any, Dict, List, Optional
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
import re
|
||||
|
||||
|
||||
@@ -291,9 +290,6 @@ def batch_process(
|
||||
files: List[gr.File], save_path: str, *components: List[gr.components.Component]
|
||||
) -> Optional[List[Image.Image]]:
|
||||
try:
|
||||
if save_path is not None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
units_count = opts.data.get("faceswaplab_units_count", 3)
|
||||
units: List[FaceSwapUnitSettings] = []
|
||||
|
||||
@@ -312,36 +308,15 @@ def batch_process(
|
||||
*components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
|
||||
)
|
||||
logger.debug("%s", pformat(postprocess_options))
|
||||
|
||||
units = [u for u in units if u.enable]
|
||||
if files is not None and len(units) > 0:
|
||||
images = []
|
||||
for file in files:
|
||||
current_images = []
|
||||
src_image = Image.open(file.name)
|
||||
swapped_images = swapper.process_images_units(
|
||||
get_current_model(),
|
||||
images=[(src_image, None)],
|
||||
units=units,
|
||||
upscaled_swapper=opts.data.get(
|
||||
"faceswaplab_upscaled_swapper", False
|
||||
),
|
||||
)
|
||||
if len(swapped_images) > 0:
|
||||
current_images += [img for img, _ in swapped_images]
|
||||
|
||||
logger.info("%s images generated", len(current_images))
|
||||
for i, img in enumerate(current_images):
|
||||
current_images[i] = enhance_image(img, postprocess_options)
|
||||
|
||||
for img in current_images:
|
||||
path = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=save_path
|
||||
).name
|
||||
img.save(path)
|
||||
|
||||
images += current_images
|
||||
return images
|
||||
images = [
|
||||
Image.open(file.name) for file in files
|
||||
] # potentially greedy but Image.open is supposed to be lazy
|
||||
return swapper.batch_process(
|
||||
images,
|
||||
save_path=save_path,
|
||||
units=units,
|
||||
postprocess_options=postprocess_options,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Batch Process error : %s", e)
|
||||
import traceback
|
||||
|
||||
Reference in New Issue
Block a user