mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Join MaskNet to guide generator
This commit is contained in:
@@ -82,6 +82,7 @@ identity_weight = 20.0
|
||||
gaze_weight = 0.05
|
||||
pose_weight = 0.05
|
||||
expression_weight = 0.05
|
||||
mask_weight = 0.5
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -42,6 +42,7 @@ identity_weight =
|
||||
gaze_weight =
|
||||
pose_weight =
|
||||
expression_weight =
|
||||
mask_weight =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
|
||||
@@ -39,6 +39,7 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
|
||||
|
||||
|
||||
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
|
||||
input_mask = input_mask.mean(dim = 1, keepdim = True)
|
||||
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
|
||||
overlay_tensor[:, 2, :, :] = 1
|
||||
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
|
||||
|
||||
@@ -26,5 +26,5 @@ def infer() -> None:
|
||||
source_tensor = io.read_image(config_source_path)
|
||||
target_tensor = io.read_image(config_target_path)
|
||||
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = generator(source_embedding, target_tensor)[0]
|
||||
output_tensor, _ = generator(source_embedding, target_tensor)
|
||||
io.write_jpeg(output_tensor, config_output_path)
|
||||
|
||||
@@ -5,7 +5,8 @@ from torch import Tensor, nn
|
||||
|
||||
from ..networks.aad import AAD
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Embedding, Feature
|
||||
from ..networks.masknet import MaskNet
|
||||
from ..types import Embedding, Feature, Mask
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
@@ -13,13 +14,17 @@ class Generator(nn.Module):
|
||||
super().__init__()
|
||||
self.encoder = UNet(config_parser)
|
||||
self.generator = AAD(config_parser)
|
||||
self.masker = MaskNet(config_parser)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
self.masker.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Feature, ...]]:
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
target_features = self.encode_features(target_tensor)
|
||||
output_tensor = self.generator(source_embedding, target_features)
|
||||
return output_tensor, target_features
|
||||
target_feature = target_features[-1]
|
||||
output_mask = self.masker(target_tensor, target_feature)
|
||||
return output_tensor, output_mask
|
||||
|
||||
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
|
||||
return self.encoder(input_tensor)
|
||||
|
||||
@@ -182,16 +182,18 @@ class GazeLoss(nn.Module):
|
||||
class MaskLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, face_parser : FaceParserModule) -> None:
|
||||
super().__init__()
|
||||
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.face_parser = face_parser
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Loss:
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
|
||||
target_mask = self.calc_mask(target_tensor)
|
||||
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
mask_loss = self.mse_loss(target_mask, output_mask)
|
||||
return mask_loss
|
||||
weighted_mask_loss = mask_loss * self.config_mask_weight
|
||||
return mask_loss, weighted_mask_loss
|
||||
|
||||
def calc_mask(self, target_tensor : Tensor) -> Tensor:
|
||||
target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear')
|
||||
|
||||
@@ -17,7 +17,6 @@ from .helper import calc_embedding, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
from .networks.masknet import MaskNet
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
@@ -42,7 +41,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.face_parser = torch.jit.load(self.config_face_parser_path, map_location ='cpu').eval()
|
||||
self.generator = Generator(config_parser)
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
self.masker = MaskNet(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss(config_parser)
|
||||
self.feature_loss = FeautureLoss(config_parser)
|
||||
@@ -55,19 +53,15 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
with torch.no_grad():
|
||||
output_tensor, target_features = self.generator(source_embedding, target_tensor)
|
||||
target_feature = target_features[-1]
|
||||
output_mask = self.masker(target_tensor, target_feature)
|
||||
output_tensor, output_mask = self.generator(source_embedding, target_tensor)
|
||||
|
||||
return output_tensor, output_mask
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]:
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
masker_optimizer = torch.optim.AdamW(self.masker.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
||||
masker_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(masker_optimizer, T_0 = 300, T_mult = 2)
|
||||
|
||||
generator_config =\
|
||||
{
|
||||
@@ -87,24 +81,16 @@ class FaceSwapperTrainer(LightningModule):
|
||||
'interval': 'step'
|
||||
}
|
||||
}
|
||||
masker_config =\
|
||||
{
|
||||
'optimizer': masker_optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': masker_scheduler,
|
||||
'interval': 'step'
|
||||
}
|
||||
}
|
||||
return generator_config, discriminator_config, masker_config
|
||||
return generator_config, discriminator_config
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
do_update = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
generator_output_tensor, generator_target_features = self.generator(source_embedding, target_tensor)
|
||||
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor)
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
generator_output_features = self.generator.encode_features(generator_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
@@ -113,16 +99,13 @@ class FaceSwapperTrainer(LightningModule):
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
|
||||
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
|
||||
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
generator_output_feature = generator_output_features[-1]
|
||||
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_feature.detach())
|
||||
mask_loss = self.mask_loss(target_tensor, generator_output_mask)
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
self.manual_backward(generator_loss)
|
||||
if do_update:
|
||||
@@ -137,13 +120,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
self.toggle_optimizer(masker_optimizer)
|
||||
self.manual_backward(mask_loss)
|
||||
if do_update:
|
||||
masker_optimizer.step()
|
||||
masker_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(masker_optimizer)
|
||||
|
||||
if self.global_step % self.config_preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user