Join MaskNet to guide the generator

This commit is contained in:
henryruhs
2025-03-16 12:23:08 +01:00
parent 803902c8bb
commit ad675ae633
7 changed files with 24 additions and 38 deletions
+1
View File
@@ -82,6 +82,7 @@ identity_weight = 20.0
gaze_weight = 0.05
pose_weight = 0.05
expression_weight = 0.05
mask_weight = 0.5
```
```
+1
View File
@@ -42,6 +42,7 @@ identity_weight =
gaze_weight =
pose_weight =
expression_weight =
mask_weight =
[training.trainer]
accumulate_size =
+1
View File
@@ -39,6 +39,7 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
input_mask = input_mask.mean(dim = 1, keepdim = True)
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
overlay_tensor[:, 2, :, :] = 1
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
+1 -1
View File
@@ -26,5 +26,5 @@ def infer() -> None:
source_tensor = io.read_image(config_source_path)
target_tensor = io.read_image(config_target_path)
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor = generator(source_embedding, target_tensor)[0]
output_tensor, _ = generator(source_embedding, target_tensor)
io.write_jpeg(output_tensor, config_output_path)
+8 -3
View File
@@ -5,7 +5,8 @@ from torch import Tensor, nn
from ..networks.aad import AAD
from ..networks.unet import UNet
from ..types import Embedding, Feature
from ..networks.masknet import MaskNet
from ..types import Embedding, Feature, Mask
class Generator(nn.Module):
@@ -13,13 +14,17 @@ class Generator(nn.Module):
super().__init__()
self.encoder = UNet(config_parser)
self.generator = AAD(config_parser)
self.masker = MaskNet(config_parser)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
self.masker.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Feature, ...]]:
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
target_features = self.encode_features(target_tensor)
output_tensor = self.generator(source_embedding, target_features)
return output_tensor, target_features
target_feature = target_features[-1]
output_mask = self.masker(target_tensor, target_feature)
return output_tensor, output_mask
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
return self.encoder(input_tensor)
+4 -2
View File
@@ -182,16 +182,18 @@ class GazeLoss(nn.Module):
class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, face_parser : FaceParserModule) -> None:
super().__init__()
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.face_parser = face_parser
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Loss:
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
target_mask = self.calc_mask(target_tensor)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, output_mask)
return mask_loss
weighted_mask_loss = mask_loss * self.config_mask_weight
return mask_loss, weighted_mask_loss
def calc_mask(self, target_tensor : Tensor) -> Tensor:
target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear')
+8 -32
View File
@@ -17,7 +17,6 @@ from .helper import calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .networks.masknet import MaskNet
from .types import Batch, Embedding, Mask, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -42,7 +41,6 @@ class FaceSwapperTrainer(LightningModule):
self.face_parser = torch.jit.load(self.config_face_parser_path, map_location ='cpu').eval()
self.generator = Generator(config_parser)
self.discriminator = Discriminator(config_parser)
self.masker = MaskNet(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.feature_loss = FeautureLoss(config_parser)
@@ -55,19 +53,15 @@ class FaceSwapperTrainer(LightningModule):
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
with torch.no_grad():
output_tensor, target_features = self.generator(source_embedding, target_tensor)
target_feature = target_features[-1]
output_mask = self.masker(target_tensor, target_feature)
output_tensor, output_mask = self.generator(source_embedding, target_tensor)
return output_tensor, output_mask
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]:
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
masker_optimizer = torch.optim.AdamW(self.masker.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
masker_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(masker_optimizer, T_0 = 300, T_mult = 2)
generator_config =\
{
@@ -87,24 +81,16 @@ class FaceSwapperTrainer(LightningModule):
'interval': 'step'
}
}
masker_config =\
{
'optimizer': masker_optimizer,
'lr_scheduler':
{
'scheduler': masker_scheduler,
'interval': 'step'
}
}
return generator_config, discriminator_config, masker_config
return generator_config, discriminator_config
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
generator_output_tensor, generator_target_features = self.generator(source_embedding, target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor)
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_features = self.generator.encode_features(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
@@ -113,16 +99,13 @@ class FaceSwapperTrainer(LightningModule):
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
generator_output_feature = generator_output_features[-1]
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_feature.detach())
mask_loss = self.mask_loss(target_tensor, generator_output_mask)
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
if do_update:
@@ -137,13 +120,6 @@ class FaceSwapperTrainer(LightningModule):
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
self.toggle_optimizer(masker_optimizer)
self.manual_backward(mask_loss)
if do_update:
masker_optimizer.step()
masker_optimizer.zero_grad()
self.untoggle_optimizer(masker_optimizer)
if self.global_step % self.config_preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)