Train MaskNet based on the output

henryruhs
2025-03-12 21:35:07 +01:00
parent 0732924f1e
commit cf0bd93814
3 changed files with 26 additions and 32 deletions
+3 -3
@@ -38,9 +38,9 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
 	return embedding
 
-def overlay_mask(target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
-	overlay_tensor = torch.zeros(*target_tensor.shape, dtype = target_tensor.dtype, device = target_tensor.device)
+def overlay_mask(input_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
+	overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
 	overlay_tensor[:, 2, :, :] = 1
 	mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8)
-	output_tensor = target_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
+	output_tensor = input_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
 	return output_tensor
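The renamed overlay_mask still writes 1 into channel index 2 of a zero tensor and alpha-blends it over the input, with the single-channel mask broadcast to three channels and clamped to 0.8 so the underlying pixels stay partly visible. A minimal usage sketch, assuming overlay_mask is in scope and using hypothetical shapes:

import torch

# hypothetical batch: two 3x256x256 images plus a single-channel soft mask
input_tensor = torch.rand(2, 3, 256, 256)
mask_tensor = torch.rand(2, 1, 256, 256)

# blends a solid color into the masked region; blend strength is capped at 0.8
output_tensor = overlay_mask(input_tensor, mask_tensor)
print(output_tensor.shape) # torch.Size([2, 3, 256, 256])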
+2 -2
@@ -34,8 +34,8 @@ class MaskNet(nn.Module):
 			UpSample(num_filters, num_filters)
 		])
 
-	def forward(self, target_tensor : Tensor, target_attribute : Attribute) -> Tensor:
-		output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1)
+	def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Tensor:
+		output_tensor = torch.cat([ input_tensor, input_attribute ], dim = 1)
 
 		for down_sample in self.down_samples:
 			output_tensor = down_sample(output_tensor)
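The forward rename is cosmetic: the method still concatenates the image tensor and its attribute map along the channel dimension before the down-sampling path. A shape sketch with hypothetical channel counts (the real attribute width depends on the generator):

import torch

# hypothetical inputs: a 3-channel image and a 256-channel attribute map at matching resolution
input_tensor = torch.rand(1, 3, 64, 64)
input_attribute = torch.rand(1, 256, 64, 64)

# channel-wise concatenation as in MaskNet.forward: 3 + 256 = 259 input channels
output_tensor = torch.cat([ input_tensor, input_attribute ], dim = 1)
print(output_tensor.shape) # torch.Size([1, 259, 64, 64])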
+21 -27
@@ -103,7 +103,6 @@ class FaceSwapperTrainer(LightningModule):
 		do_update = (batch_index + 1) % self.config_accumulate_size == 0
 		generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
-		self.toggle_optimizer(generator_optimizer)
 		source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
 		generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor)
 		discriminator_output_tensors = self.discriminator(generator_output_tensor)
@@ -115,40 +114,35 @@ class FaceSwapperTrainer(LightningModule):
 		gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
 		generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
-		self.manual_backward(generator_loss)
-
-		if do_update:
-			generator_optimizer.step()
-			generator_optimizer.zero_grad()
-
-		self.untoggle_optimizer(generator_optimizer)
-		self.toggle_optimizer(masker_optimizer)
-		target_attribute = generator_output_attributes[-1]
-		mask_tensor = self.masker(target_tensor, target_attribute.detach())
-		mask_loss = self.mask_loss(target_tensor, mask_tensor)
-		self.manual_backward(mask_loss)
-
-		if do_update:
-			masker_optimizer.step()
-			masker_optimizer.zero_grad()
-
-		self.untoggle_optimizer(masker_optimizer)
-		self.toggle_optimizer(discriminator_optimizer)
 		discriminator_source_tensors = self.discriminator(source_tensor)
 		discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
 		discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
-		self.manual_backward(discriminator_loss)
+		generator_output_attribute = generator_output_attributes[-1]
+		mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
+		mask_loss = self.mask_loss(generator_output_tensor.detach(), mask_tensor)
+		self.toggle_optimizer(generator_optimizer)
+		self.manual_backward(generator_loss, retain_graph = True)
 
 		if do_update:
+			generator_optimizer.step()
+			generator_optimizer.zero_grad()
+
+		self.untoggle_optimizer(generator_optimizer)
+		self.toggle_optimizer(discriminator_optimizer)
+		self.manual_backward(discriminator_loss)
+
+		if do_update:
 			discriminator_optimizer.step()
 			discriminator_optimizer.zero_grad()
 
 		self.untoggle_optimizer(discriminator_optimizer)
+		self.toggle_optimizer(masker_optimizer)
+		self.manual_backward(mask_loss)
+
+		if do_update:
+			masker_optimizer.step()
+			masker_optimizer.zero_grad()
+
+		self.untoggle_optimizer(masker_optimizer)
 
 		if self.global_step % self.config_preview_frequency == 0:
 			self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor)
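This hunk carries the substantive change: all three losses are now computed up front, the masker is trained on the detached generator output rather than the target (hence the commit title), and the generator backward passes retain_graph = True, presumably so the graph survives any later backward in the same step. A standalone sketch of the same update order in plain PyTorch, with hypothetical toy models and losses; Lightning's toggle_optimizer/untoggle_optimizer gradient isolation is approximated here by calling zero_grad before each backward:

import torch
import torch.nn as nn

generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)
masker = nn.Linear(4, 1)

generator_optimizer = torch.optim.Adam(generator.parameters())
discriminator_optimizer = torch.optim.Adam(discriminator.parameters())
masker_optimizer = torch.optim.Adam(masker.parameters())

source_tensor = torch.rand(8, 4)
target_tensor = torch.rand(8, 4)

# compute every loss before any backward pass
output_tensor = generator(source_tensor)
generator_loss = (output_tensor - target_tensor).abs().mean() - discriminator(output_tensor).mean()
discriminator_loss = (discriminator(output_tensor.detach()) - discriminator(source_tensor)).mean()
mask_loss = masker(output_tensor.detach()).sigmoid().mean() # masker sees the generator output, not the target

# step generator, discriminator and masker in turn
for optimizer, loss in [ (generator_optimizer, generator_loss), (discriminator_optimizer, discriminator_loss), (masker_optimizer, mask_loss) ]:
	optimizer.zero_grad()
	loss.backward(retain_graph = True)
	optimizer.step()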
@@ -175,10 +169,10 @@ class FaceSwapperTrainer(LightningModule):
 	def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
 		preview_limit = 8
 		preview_cells = []
-		mask_tensor = overlay_mask(target_tensor, mask_tensor)
+		overlay_tensor = overlay_mask(output_tensor, mask_tensor)
 
 		for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]):
-			preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor ], dim = 2)
+			preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
 			preview_cells.append(preview_cell)
 
 		preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
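For reference, each preview row is one sample's source, target, output and overlay concatenated along the width (dim 2 of a CHW tensor), and the rows are stacked along the height before restoring a batch dimension. A self-contained sketch with hypothetical tensors, indexing the batch-level overlay per sample so the shapes line up:

import torch

preview_limit = 8

# hypothetical batch of eight 3x64x64 previews
source_tensor = torch.rand(8, 3, 64, 64)
target_tensor = torch.rand(8, 3, 64, 64)
output_tensor = torch.rand(8, 3, 64, 64)
overlay_tensor = torch.rand(8, 3, 64, 64)

preview_cells = []

for index in range(min(preview_limit, source_tensor.shape[0])):
	# one row per sample: source | target | output | overlay, joined along the width
	preview_cell = torch.cat([ source_tensor[index], target_tensor[index], output_tensor[index], overlay_tensor[index] ], dim = 2)
	preview_cells.append(preview_cell)

# stack rows along the height and restore a batch dimension
preview_grid = torch.cat(preview_cells, dim = 1).unsqueeze(0)
print(preview_grid.shape) # torch.Size([1, 3, 512, 256])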