mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
change mask preview
This commit is contained in:
@@ -36,3 +36,11 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
|
||||
embedding = embedder(crop_tensor)
|
||||
embedding = nn.functional.normalize(embedding, p = 2)
|
||||
return embedding
|
||||
|
||||
|
||||
def overlay_mask(target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
|
||||
color_tensor = torch.zeros(*list(target_tensor.shape), dtype = target_tensor.dtype, device = target_tensor.device)
|
||||
color_tensor[:, 2, :, :] = 1
|
||||
mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8)
|
||||
output_tensor = target_tensor * (1 - mask_tensor) + color_tensor * mask_tensor
|
||||
return output_tensor
|
||||
|
||||
@@ -13,7 +13,7 @@ from torch.utils.data import Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .helper import calc_embedding, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
@@ -172,7 +172,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
|
||||
preview_limit = 8
|
||||
preview_cells = []
|
||||
mask_tensor = (mask_tensor.repeat(1, 3, 1, 1) - 0.5) * 2
|
||||
mask_tensor = overlay_mask(target_tensor, mask_tensor)
|
||||
|
||||
for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]):
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor ], dim = 2)
|
||||
|
||||
Reference in New Issue
Block a user