diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 8ea5b3a..7fc5e9e 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -36,3 +36,11 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P embedding = embedder(crop_tensor) embedding = nn.functional.normalize(embedding, p = 2) return embedding + + +def overlay_mask(target_tensor : Tensor, mask_tensor : Tensor) -> Tensor: + color_tensor = torch.zeros(*list(target_tensor.shape), dtype = target_tensor.dtype, device = target_tensor.device) + color_tensor[:, 2, :, :] = 1 + mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8) + output_tensor = target_tensor * (1 - mask_tensor) + color_tensor * mask_tensor + return output_tensor diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 7040725..7bda57b 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -13,7 +13,7 @@ from torch.utils.data import Dataset, random_split from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import DynamicDataset -from .helper import calc_embedding +from .helper import calc_embedding, overlay_mask from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss @@ -172,7 +172,7 @@ class FaceSwapperTrainer(LightningModule): def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None: preview_limit = 8 preview_cells = [] - mask_tensor = (mask_tensor.repeat(1, 3, 1, 1) - 0.5) * 2 + mask_tensor = overlay_mask(target_tensor, mask_tensor) for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]): preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor ], dim = 2)