mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Train MaskNet based on the output
This commit is contained in:
@@ -38,9 +38,9 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
|
||||
return embedding
|
||||
|
||||
|
||||
def overlay_mask(target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
|
||||
overlay_tensor = torch.zeros(*target_tensor.shape, dtype = target_tensor.dtype, device = target_tensor.device)
|
||||
def overlay_mask(input_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
|
||||
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
|
||||
overlay_tensor[:, 2, :, :] = 1
|
||||
mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8)
|
||||
output_tensor = target_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
|
||||
output_tensor = input_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
|
||||
return output_tensor
|
||||
|
||||
@@ -34,8 +34,8 @@ class MaskNet(nn.Module):
|
||||
UpSample(num_filters, num_filters)
|
||||
])
|
||||
|
||||
def forward(self, target_tensor : Tensor, target_attribute : Attribute) -> Tensor:
|
||||
output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1)
|
||||
def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Tensor:
|
||||
output_tensor = torch.cat([ input_tensor, input_attribute ], dim = 1)
|
||||
|
||||
for down_sample in self.down_samples:
|
||||
output_tensor = down_sample(output_tensor)
|
||||
|
||||
@@ -103,7 +103,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
do_update = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
@@ -115,40 +114,35 @@ class FaceSwapperTrainer(LightningModule):
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
|
||||
|
||||
self.manual_backward(generator_loss)
|
||||
|
||||
if do_update:
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
|
||||
self.toggle_optimizer(masker_optimizer)
|
||||
target_attribute = generator_output_attributes[-1]
|
||||
mask_tensor = self.masker(target_tensor, target_attribute.detach())
|
||||
mask_loss = self.mask_loss(target_tensor, mask_tensor)
|
||||
|
||||
self.manual_backward(mask_loss)
|
||||
|
||||
if do_update:
|
||||
masker_optimizer.step()
|
||||
masker_optimizer.zero_grad()
|
||||
|
||||
self.untoggle_optimizer(masker_optimizer)
|
||||
|
||||
self.toggle_optimizer(discriminator_optimizer)
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
self.manual_backward(discriminator_loss)
|
||||
generator_output_attribute = generator_output_attributes[-1]
|
||||
mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
|
||||
mask_loss = self.mask_loss(generator_output_tensor.detach(), mask_tensor)
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
self.manual_backward(generator_loss, retain_graph = True)
|
||||
if do_update:
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
|
||||
self.toggle_optimizer(discriminator_optimizer)
|
||||
self.manual_backward(discriminator_loss)
|
||||
if do_update:
|
||||
discriminator_optimizer.step()
|
||||
discriminator_optimizer.zero_grad()
|
||||
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
self.toggle_optimizer(masker_optimizer)
|
||||
self.manual_backward(mask_loss)
|
||||
if do_update:
|
||||
masker_optimizer.step()
|
||||
masker_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(masker_optimizer)
|
||||
|
||||
if self.global_step % self.config_preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor)
|
||||
|
||||
@@ -175,10 +169,10 @@ class FaceSwapperTrainer(LightningModule):
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
|
||||
preview_limit = 8
|
||||
preview_cells = []
|
||||
mask_tensor = overlay_mask(target_tensor, mask_tensor)
|
||||
overlay_tensor = overlay_mask(output_tensor, mask_tensor)
|
||||
|
||||
for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]):
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor ], dim = 2)
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
|
||||
preview_cells.append(preview_cell)
|
||||
|
||||
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
|
||||
|
||||
Reference in New Issue
Block a user