|
|
|
@@ -17,6 +17,7 @@ from .helper import calc_embedding
|
|
|
|
|
from .models.discriminator import Discriminator
|
|
|
|
|
from .models.generator import Generator
|
|
|
|
|
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
|
|
|
|
from .networks.masknet import MaskNet
|
|
|
|
|
from .types import Batch, Embedding, OptimizerSet
|
|
|
|
|
|
|
|
|
|
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
|
|
|
@@ -40,6 +41,7 @@ class FaceSwapperTrainer(LightningModule):
|
|
|
|
|
self.parser = torch.jit.load(self.config_parser_path, map_location = 'cpu').eval()
|
|
|
|
|
self.generator = Generator(config_parser)
|
|
|
|
|
self.discriminator = Discriminator(config_parser)
|
|
|
|
|
self.masker = MaskNet(config_parser)
|
|
|
|
|
self.discriminator_loss = DiscriminatorLoss()
|
|
|
|
|
self.adversarial_loss = AdversarialLoss(config_parser)
|
|
|
|
|
self.attribute_loss = AttributeLoss(config_parser)
|
|
|
|
@@ -54,11 +56,13 @@ class FaceSwapperTrainer(LightningModule):
|
|
|
|
|
output_tensor, mask_tensor = self.generator(source_embedding, target_tensor)
|
|
|
|
|
return output_tensor, mask_tensor
|
|
|
|
|
|
|
|
|
|
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
|
|
|
|
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]:
|
|
|
|
|
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
|
|
|
|
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
|
|
|
|
masker_optimizer = torch.optim.AdamW(self.masker.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
|
|
|
|
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
|
|
|
|
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
|
|
|
|
masker_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(masker_optimizer, T_0 = 300, T_mult = 2)
|
|
|
|
|
|
|
|
|
|
generator_config =\
|
|
|
|
|
{
|
|
|
|
@@ -78,14 +82,23 @@ class FaceSwapperTrainer(LightningModule):
|
|
|
|
|
'interval': 'step'
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return generator_config, discriminator_config
|
|
|
|
|
masker_config =\
|
|
|
|
|
{
|
|
|
|
|
'optimizer': masker_optimizer,
|
|
|
|
|
'lr_scheduler':
|
|
|
|
|
{
|
|
|
|
|
'scheduler': masker_scheduler,
|
|
|
|
|
'interval': 'step'
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return generator_config, discriminator_config, masker_config
|
|
|
|
|
|
|
|
|
|
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
|
|
|
|
source_tensor, target_tensor = batch
|
|
|
|
|
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
|
|
|
|
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
|
|
|
|
|
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
|
|
|
|
target_attributes = self.generator.get_attributes(target_tensor)
|
|
|
|
|
generator_output_tensor, generator_mask_tensor = self.generator(source_embedding, target_tensor)
|
|
|
|
|
generator_output_tensor = self.generator(source_embedding, target_tensor)
|
|
|
|
|
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
|
|
|
|
|
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
|
|
|
|
|
|
|
|
@@ -96,14 +109,22 @@ class FaceSwapperTrainer(LightningModule):
|
|
|
|
|
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
|
|
|
|
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
|
|
|
|
|
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
|
|
|
|
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_mask_tensor)
|
|
|
|
|
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss
|
|
|
|
|
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
|
|
|
|
|
|
|
|
|
|
generator_optimizer.zero_grad()
|
|
|
|
|
self.manual_backward(generator_loss)
|
|
|
|
|
generator_optimizer.step()
|
|
|
|
|
self.untoggle_optimizer(generator_optimizer)
|
|
|
|
|
|
|
|
|
|
self.toggle_optimizer(masker_optimizer)
|
|
|
|
|
mask_tensor = self.masker(target_tensor, target_attributes[-1].detach())
|
|
|
|
|
mask_loss = self.mask_loss(target_tensor, mask_tensor)
|
|
|
|
|
|
|
|
|
|
masker_optimizer.zero_grad()
|
|
|
|
|
self.manual_backward(mask_loss)
|
|
|
|
|
masker_optimizer.step()
|
|
|
|
|
self.untoggle_optimizer(masker_optimizer)
|
|
|
|
|
|
|
|
|
|
self.toggle_optimizer(discriminator_optimizer)
|
|
|
|
|
discriminator_source_tensors = self.discriminator(source_tensor)
|
|
|
|
|
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
|
|
|
@@ -131,7 +152,7 @@ class FaceSwapperTrainer(LightningModule):
|
|
|
|
|
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
|
|
|
|
source_tensor, target_tensor = batch
|
|
|
|
|
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
|
|
|
|
output_tensor, mask_tensor = self.generator(source_embedding, target_tensor)
|
|
|
|
|
output_tensor = self.generator(source_embedding, target_tensor)
|
|
|
|
|
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
|
|
|
|
|
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
|
|
|
|
self.log('validation_score', validation_score, prog_bar = True)
|
|
|
|
|