diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py
index 352cfd4..98eb971 100644
--- a/face_swapper/src/helper.py
+++ b/face_swapper/src/helper.py
@@ -38,9 +38,9 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
 	return embedding
 
 
-def overlay_mask(target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
-	overlay_tensor = torch.zeros(*target_tensor.shape, dtype = target_tensor.dtype, device = target_tensor.device)
+def overlay_mask(input_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
+	overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
 	overlay_tensor[:, 2, :, :] = 1
 	mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8)
-	output_tensor = target_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
+	output_tensor = input_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
 	return output_tensor
diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py
index 0c43c4d..a26d56b 100644
--- a/face_swapper/src/networks/masknet.py
+++ b/face_swapper/src/networks/masknet.py
@@ -34,8 +34,8 @@ class MaskNet(nn.Module):
 			UpSample(num_filters, num_filters)
 		])
 
-	def forward(self, target_tensor : Tensor, target_attribute : Attribute) -> Tensor:
-		output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1)
+	def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Tensor:
+		output_tensor = torch.cat([ input_tensor, input_attribute ], dim = 1)
 
 		for down_sample in self.down_samples:
 			output_tensor = down_sample(output_tensor)
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
index 4b9a261..fca3105 100644
--- a/face_swapper/src/training.py
+++ b/face_swapper/src/training.py
@@ -103,7 +103,6 @@ class FaceSwapperTrainer(LightningModule):
 		do_update = (batch_index + 1) % self.config_accumulate_size == 0
 		generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
 
-		self.toggle_optimizer(generator_optimizer)
 		source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
 		generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor)
 		discriminator_output_tensors = self.discriminator(generator_output_tensor)
@@ -115,40 +114,35 @@ class FaceSwapperTrainer(LightningModule):
 
 		gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
 		generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
 
-		self.manual_backward(generator_loss)
-
-		if do_update:
-			generator_optimizer.step()
-			generator_optimizer.zero_grad()
-
-		self.untoggle_optimizer(generator_optimizer)
-
-		self.toggle_optimizer(masker_optimizer)
-		target_attribute = generator_output_attributes[-1]
-		mask_tensor = self.masker(target_tensor, target_attribute.detach())
-		mask_loss = self.mask_loss(target_tensor, mask_tensor)
-
-		self.manual_backward(mask_loss)
-
-		if do_update:
-			masker_optimizer.step()
-			masker_optimizer.zero_grad()
-
-		self.untoggle_optimizer(masker_optimizer)
-
-		self.toggle_optimizer(discriminator_optimizer)
 		discriminator_source_tensors = self.discriminator(source_tensor)
 		discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
 		discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
-		self.manual_backward(discriminator_loss)
+		generator_output_attribute = generator_output_attributes[-1]
+		mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
+		mask_loss = self.mask_loss(generator_output_tensor.detach(), mask_tensor)
+		self.toggle_optimizer(generator_optimizer)
+		self.manual_backward(generator_loss, retain_graph = True)
+		if do_update:
+			generator_optimizer.step()
+			generator_optimizer.zero_grad()
+		self.untoggle_optimizer(generator_optimizer)
+
+		self.toggle_optimizer(discriminator_optimizer)
+		self.manual_backward(discriminator_loss)
 
 		if do_update:
 			discriminator_optimizer.step()
 			discriminator_optimizer.zero_grad()
-
 		self.untoggle_optimizer(discriminator_optimizer)
 
+		self.toggle_optimizer(masker_optimizer)
+		self.manual_backward(mask_loss)
+		if do_update:
+			masker_optimizer.step()
+			masker_optimizer.zero_grad()
+		self.untoggle_optimizer(masker_optimizer)
+
 		if self.global_step % self.config_preview_frequency == 0:
 			self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor)
 
@@ -175,10 +169,10 @@ class FaceSwapperTrainer(LightningModule):
 	def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
 		preview_limit = 8
 		preview_cells = []
-		mask_tensor = overlay_mask(target_tensor, mask_tensor)
+		overlay_tensor = overlay_mask(output_tensor, mask_tensor)
 
 		for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]):
-			preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor ], dim = 2)
+			preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
 			preview_cells.append(preview_cell)
 
 		preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)