Mirror of https://github.com/facefusion/facefusion-labs.git (synced 2026-05-22 23:59:40 +02:00)
Mask typing and naming related updates
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding, WarpTemplate, WarpTemplateSet
|
||||
from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet
|
||||
|
||||
WARP_TEMPLATE_SET : WarpTemplateSet =\
|
||||
{
|
||||
@@ -38,9 +38,9 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
|
||||
return embedding
|
||||
|
||||
|
||||
def overlay_mask(input_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
|
||||
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
|
||||
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
|
||||
overlay_tensor[:, 2, :, :] = 1
|
||||
mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8)
|
||||
output_tensor = input_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
|
||||
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
|
||||
output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask
|
||||
return output_tensor
|
||||
|
||||
@@ -7,7 +7,7 @@ from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, MotionExtractorModule
|
||||
from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
@@ -186,11 +186,11 @@ class MaskLoss(nn.Module):
|
||||
self.face_parser = face_parser
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Loss:
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Loss:
|
||||
target_mask = self.calc_mask(target_tensor)
|
||||
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size)
|
||||
mask_loss = self.mse_loss(target_mask, mask_tensor)
|
||||
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
mask_loss = self.mse_loss(target_mask, output_mask)
|
||||
return mask_loss
|
||||
|
||||
def calc_mask(self, target_tensor : Tensor) -> Tensor:
|
||||
|
||||
@@ -3,7 +3,7 @@ from configparser import ConfigParser
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Attribute
|
||||
from ..types import Attribute, Mask
|
||||
|
||||
|
||||
class MaskNet(nn.Module):
|
||||
@@ -14,7 +14,7 @@ class MaskNet(nn.Module):
|
||||
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
|
||||
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
|
||||
self.up_samples = self.create_up_samples(self.config_num_filters)
|
||||
self.bottleneck = BottleNeck(self.config_num_filters * 2)
|
||||
self.bottleneck = BottleNeck(self.config_num_filters * 4)
|
||||
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
@@ -23,31 +23,33 @@ class MaskNet(nn.Module):
|
||||
return nn.ModuleList(
|
||||
[
|
||||
DownSample(input_channels, num_filters),
|
||||
DownSample(num_filters, num_filters * 2)
|
||||
DownSample(num_filters, num_filters * 2),
|
||||
DownSample(num_filters, num_filters * 4)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def create_up_samples(num_filters : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
UpSample(num_filters * 4, num_filters),
|
||||
UpSample(num_filters * 2, num_filters),
|
||||
UpSample(num_filters, num_filters)
|
||||
])
|
||||
|
||||
def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Tensor:
|
||||
output_tensor = torch.cat([ input_tensor, input_attribute ], dim = 1)
|
||||
def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Mask:
|
||||
output_mask = torch.cat([ input_tensor, input_attribute ], dim = 1)
|
||||
|
||||
for down_sample in self.down_samples:
|
||||
output_tensor = down_sample(output_tensor)
|
||||
output_mask = down_sample(output_mask)
|
||||
|
||||
output_tensor = self.bottleneck(output_tensor)
|
||||
output_mask = self.bottleneck(output_mask)
|
||||
|
||||
for up_sample in self.up_samples:
|
||||
output_tensor = up_sample(output_tensor)
|
||||
output_mask = up_sample(output_mask)
|
||||
|
||||
output_tensor = self.conv(output_tensor)
|
||||
output_tensor = self.sigmoid(output_tensor)
|
||||
return output_tensor
|
||||
output_mask = self.conv(output_mask)
|
||||
output_mask = self.sigmoid(output_mask)
|
||||
return output_mask
|
||||
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
|
||||
@@ -18,7 +18,7 @@ from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
from .networks.masknet import MaskNet
|
||||
from .types import Batch, Embedding, OptimizerSet
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
|
||||
@@ -57,9 +57,9 @@ class FaceSwapperTrainer(LightningModule):
|
||||
with torch.no_grad():
|
||||
output_tensor, target_attributes = self.generator(source_embedding, target_tensor)
|
||||
target_attribute = target_attributes[-1]
|
||||
mask_tensor = self.masker(target_tensor, target_attribute)
|
||||
output_mask = self.masker(target_tensor, target_attribute)
|
||||
|
||||
return output_tensor, mask_tensor
|
||||
return output_tensor, output_mask
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]:
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
@@ -120,8 +120,8 @@ class FaceSwapperTrainer(LightningModule):
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
generator_output_attribute = generator_output_attributes[-1]
|
||||
mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
|
||||
mask_loss = self.mask_loss(target_tensor, mask_tensor)
|
||||
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
|
||||
mask_loss = self.mask_loss(target_tensor, generator_output_mask)
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
self.manual_backward(generator_loss)
|
||||
@@ -145,7 +145,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.untoggle_optimizer(masker_optimizer)
|
||||
|
||||
if self.global_step % self.config_preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor)
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)
|
||||
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
|
||||
@@ -167,10 +167,10 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.log('validation_score', validation_score, prog_bar = True)
|
||||
return validation_score
|
||||
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None:
|
||||
preview_limit = 8
|
||||
preview_cells = []
|
||||
overlay_tensor = overlay_mask(output_tensor, mask_tensor)
|
||||
overlay_tensor = overlay_mask(output_tensor, output_mask)
|
||||
|
||||
for source_tensor, target_tensor, output_tensor, overlay_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]):
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
|
||||
|
||||
@@ -8,6 +8,7 @@ BatchMode = Literal['equal', 'same', 'different']
|
||||
|
||||
Attribute : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
Mask : TypeAlias = Tensor
|
||||
Loss : TypeAlias = Tensor
|
||||
|
||||
Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
Reference in New Issue
Block a user