This commit is contained in:
harisreedhar
2025-03-10 13:43:52 +05:30
committed by henryruhs
parent 1659805b08
commit a3ac4d5ddd
5 changed files with 32 additions and 20 deletions
-1
View File
@@ -81,7 +81,6 @@ identity_weight = 20.0
gaze_weight = 0.0
pose_weight = 0.0
expression_weight = 0.0
mask_weight = 1.0
```
```
-1
View File
@@ -41,7 +41,6 @@ identity_weight =
gaze_weight =
pose_weight =
expression_weight =
mask_weight =
[training.trainer]
learning_rate =
+2 -7
View File
@@ -1,10 +1,8 @@
from configparser import ConfigParser
from typing import Tuple
from torch import Tensor, nn
from ..networks.aad import AAD
from ..networks.masknet import MaskNet
from ..networks.unet import UNet
from ..types import Attributes, Embedding
@@ -14,16 +12,13 @@ class Generator(nn.Module):
super().__init__()
self.encoder = UNet(config_parser)
self.generator = AAD(config_parser)
self.masker = MaskNet(config_parser)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
self.masker.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
target_attributes = self.get_attributes(target_tensor)
output_tensor = self.generator(source_embedding, target_attributes)
mask_tensor = self.masker(target_tensor, target_attributes[-1])
return output_tensor, mask_tensor
return output_tensor
def get_attributes(self, input_tensor : Tensor) -> Attributes:
return self.encoder(input_tensor)
+2 -4
View File
@@ -180,18 +180,16 @@ class GazeLoss(nn.Module):
class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, parser : ParserModule) -> None:
super().__init__()
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.parser = parser
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
target_mask = self.calc_mask(target_tensor)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, mask_tensor)
weighted_mask_loss = mask_loss * self.config_mask_weight
return mask_loss, weighted_mask_loss
return mask_loss
def calc_mask(self, target_tensor : Tensor) -> Tensor:
target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear')
+28 -7
View File
@@ -17,6 +17,7 @@ from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .networks.masknet import MaskNet
from .types import Batch, Embedding, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -40,6 +41,7 @@ class FaceSwapperTrainer(LightningModule):
self.parser = torch.jit.load(self.config_parser_path, map_location = 'cpu').eval()
self.generator = Generator(config_parser)
self.discriminator = Discriminator(config_parser)
self.masker = MaskNet(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.attribute_loss = AttributeLoss(config_parser)
@@ -54,11 +56,13 @@ class FaceSwapperTrainer(LightningModule):
output_tensor, mask_tensor = self.generator(source_embedding, target_tensor)
return output_tensor, mask_tensor
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]:
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
masker_optimizer = torch.optim.AdamW(self.masker.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
masker_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(masker_optimizer, T_0 = 300, T_mult = 2)
generator_config =\
{
@@ -78,14 +82,23 @@ class FaceSwapperTrainer(LightningModule):
'interval': 'step'
}
}
return generator_config, discriminator_config
masker_config =\
{
'optimizer': masker_optimizer,
'lr_scheduler':
{
'scheduler': masker_scheduler,
'interval': 'step'
}
}
return generator_config, discriminator_config, masker_config
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_attributes = self.generator.get_attributes(target_tensor)
generator_output_tensor, generator_mask_tensor = self.generator(source_embedding, target_tensor)
generator_output_tensor = self.generator(source_embedding, target_tensor)
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
@@ -96,14 +109,22 @@ class FaceSwapperTrainer(LightningModule):
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_mask_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
generator_optimizer.zero_grad()
self.manual_backward(generator_loss)
generator_optimizer.step()
self.untoggle_optimizer(generator_optimizer)
self.toggle_optimizer(masker_optimizer)
mask_tensor = self.masker(target_tensor, target_attributes[-1].detach())
mask_loss = self.mask_loss(target_tensor, mask_tensor)
masker_optimizer.zero_grad()
self.manual_backward(mask_loss)
masker_optimizer.step()
self.untoggle_optimizer(masker_optimizer)
self.toggle_optimizer(discriminator_optimizer)
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
@@ -131,7 +152,7 @@ class FaceSwapperTrainer(LightningModule):
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
output_tensor, mask_tensor = self.generator(source_embedding, target_tensor)
output_tensor = self.generator(source_embedding, target_tensor)
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, prog_bar = True)