Compare commits
1 Commits
512_beta
...
andreasjan
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4dc201d647 |
13
predict.py
13
predict.py
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
from glob import glob
|
||||
import cog
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -16,6 +18,12 @@ from insightface_func.face_detect_crop_multi import Face_detect_crop as Face_det
|
||||
from insightface_func.face_detect_crop_single import Face_detect_crop as Face_detect_crop_single
|
||||
|
||||
|
||||
TARGET_OPTIONS_DIR = "cog-targets"
|
||||
|
||||
def list_target_options():
|
||||
return [os.path.basename(p).split(".")[0] for p in glob(f"{TARGET_OPTIONS_DIR}/*.jpg")]
|
||||
|
||||
|
||||
class Predictor(cog.Predictor):
|
||||
def setup(self):
|
||||
self.transformer_Arcface = transforms.Compose([
|
||||
@@ -24,10 +32,11 @@ class Predictor(cog.Predictor):
|
||||
])
|
||||
|
||||
@cog.input("source", type=Path, help="source image")
|
||||
@cog.input("target", type=Path, help="target image")
|
||||
@cog.input("target", type=str, options=list_target_options(), help="target image")
|
||||
@cog.input("mode", type=str, options=['single', 'all'], default='all',
|
||||
help="swap a single face (the one with highest confidence by face detection) or all faces in the target image")
|
||||
def predict(self, source, target, mode='all'):
|
||||
target_path = f"{TARGET_OPTIONS_DIR}/{target}.jpg"
|
||||
|
||||
app = Face_detect_crop_multi(name='antelope', root='./insightface_func/models')
|
||||
|
||||
@@ -39,7 +48,7 @@ class Predictor(cog.Predictor):
|
||||
options = TestOptions()
|
||||
options.initialize()
|
||||
opt = options.parser.parse_args(["--Arc_path", 'arcface_model/arcface_checkpoint.tar', "--pic_a_path", str(source),
|
||||
"--pic_b_path", str(target), "--isTrain", False, "--no_simswaplogo"])
|
||||
"--pic_b_path", target_path, "--isTrain", False, "--no_simswaplogo"])
|
||||
|
||||
str_ids = opt.gpu_ids.split(',')
|
||||
opt.gpu_ids = []
|
||||
|
||||
Reference in New Issue
Block a user