mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Merge pull request #58 from facefusion/refactor/generator-return-attributes
Let the generator return target attributes
This commit is contained in:
@@ -16,13 +16,10 @@ class Generator(nn.Module):
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
target_attributes = self.get_attributes(target_tensor)
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]:
|
||||
target_attributes = self.encoder(target_tensor)
|
||||
output_tensor = self.generator(source_embedding, target_attributes)
|
||||
return output_tensor
|
||||
|
||||
def get_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
|
||||
return self.encoder(input_tensor)
|
||||
return output_tensor, target_attributes
|
||||
|
||||
|
||||
def init_weight(module : nn.Module) -> None:
|
||||
|
||||
@@ -3,7 +3,7 @@ from configparser import ConfigParser
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import Attribute
|
||||
from ..types import Attribute
|
||||
|
||||
|
||||
class MaskNet(nn.Module):
|
||||
|
||||
@@ -55,8 +55,8 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
with torch.no_grad():
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
target_attribute = self.generator.get_attributes(target_tensor)[-1]
|
||||
output_tensor, target_attributes = self.generator(source_embedding, target_tensor)
|
||||
target_attribute = target_attributes[-1]
|
||||
mask_tensor = self.masker(target_tensor, target_attribute)
|
||||
|
||||
return output_tensor, mask_tensor
|
||||
@@ -105,12 +105,10 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_attributes = self.generator.get_attributes(target_tensor)
|
||||
generator_output_tensor = self.generator(source_embedding, target_tensor)
|
||||
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
|
||||
generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_output_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
|
||||
@@ -126,7 +124,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
|
||||
self.toggle_optimizer(masker_optimizer)
|
||||
target_attribute = target_attributes[-1].detach()
|
||||
target_attribute = generator_output_attributes[-1].detach()
|
||||
mask_tensor = self.masker(target_tensor, target_attribute)
|
||||
mask_loss = self.mask_loss(target_tensor, mask_tensor)
|
||||
|
||||
@@ -168,7 +166,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
output_tensor, _ = self.generator(source_embedding, target_tensor)
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
|
||||
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
||||
self.log('validation_score', validation_score, prog_bar = True)
|
||||
|
||||
Reference in New Issue
Block a user