improve test, add extract to api
This commit is contained in:
+36
-1
@@ -8,6 +8,7 @@ import base64, io
|
||||
from io import BytesIO
|
||||
from typing import List, Tuple, Optional
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
|
||||
class InpaintingWhen(Enum):
|
||||
@@ -151,7 +152,7 @@ class FaceSwapRequest(BaseModel):
|
||||
|
||||
class FaceSwapResponse(BaseModel):
|
||||
images: List[str] = Field(description="base64 swapped image", default=None)
|
||||
infos: List[str]
|
||||
infos: Optional[List[str]] # not really used atm
|
||||
|
||||
@property
|
||||
def pil_images(self) -> Image.Image:
|
||||
@@ -171,6 +172,23 @@ class FaceSwapCompareRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class FaceSwapExtractRequest(BaseModel):
|
||||
images: List[str] = Field(
|
||||
description="base64 reference image",
|
||||
examples=["data:image/jpeg;base64,/9j/4AAQSkZJRgABAQECWAJYAAD...."],
|
||||
default=None,
|
||||
)
|
||||
postprocessing: Optional[PostProcessingOptions]
|
||||
|
||||
|
||||
class FaceSwapExtractResponse(BaseModel):
|
||||
images: List[str] = Field(description="base64 face images", default=None)
|
||||
|
||||
@property
|
||||
def pil_images(self) -> Image.Image:
|
||||
return [base64_to_pil(img) for img in self.images]
|
||||
|
||||
|
||||
def pil_to_base64(img: Image.Image) -> np.array: # type:ignore
|
||||
if isinstance(img, str):
|
||||
img = Image.open(img)
|
||||
@@ -192,3 +210,20 @@ def base64_to_pil(base64str: Optional[str]) -> Optional[Image.Image]:
|
||||
# if no data URL scheme, just decode
|
||||
img_bytes = base64.b64decode(base64str)
|
||||
return Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
|
||||
def compare_faces(
|
||||
image1: Image.Image, image2: Image.Image, base_url: str = "http://localhost:7860"
|
||||
) -> float:
|
||||
request = FaceSwapCompareRequest(
|
||||
image1=pil_to_base64(image1),
|
||||
image2=pil_to_base64(image2),
|
||||
)
|
||||
|
||||
result = requests.post(
|
||||
url=f"{base_url}/faceswaplab/compare",
|
||||
data=request.json(),
|
||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
)
|
||||
|
||||
return float(result.text)
|
||||
|
||||
Reference in New Issue
Block a user