mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
cleaning
This commit is contained in:
@@ -17,7 +17,7 @@ from .data_loader import DataLoaderVGG
|
||||
from .discriminator import MultiscaleDiscriminator
|
||||
from .generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
from .helper import hinge_fake_loss, hinge_real_loss
|
||||
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, Loss, Padding, SourceEmbedding, TargetAttributes, VisionTensor
|
||||
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, LossTensor, Padding, SourceEmbedding, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -88,7 +88,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True)
|
||||
return generator_losses.get('loss_generator')
|
||||
|
||||
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> Loss:
|
||||
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
|
||||
loss_adversarial = torch.Tensor(0)
|
||||
|
||||
for discriminator_output in discriminator_outputs:
|
||||
@@ -96,7 +96,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
loss_adversarial = torch.mean(loss_adversarial)
|
||||
return loss_adversarial
|
||||
|
||||
def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> Loss:
|
||||
def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> LossTensor:
|
||||
loss_attribute = torch.Tensor(0)
|
||||
swap_attributes = self.generator.get_attributes(swap_tensor)
|
||||
|
||||
@@ -105,19 +105,19 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
loss_attribute *= 0.5
|
||||
return loss_attribute
|
||||
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> Loss:
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
|
||||
loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
|
||||
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
|
||||
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
|
||||
return loss_reconstruction
|
||||
|
||||
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> Loss:
|
||||
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
|
||||
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
|
||||
return loss_id
|
||||
|
||||
def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
|
||||
def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
|
||||
swap_motion_features = self.get_pose_features(swap_tensor)
|
||||
target_motion_features = self.get_pose_features(target_tensor)
|
||||
loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
|
||||
@@ -126,7 +126,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature)
|
||||
return loss_tsr
|
||||
|
||||
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
|
||||
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
|
||||
swap_landmark = self.get_face_landmarks(swap_tensor)
|
||||
target_landmark = self.get_face_landmarks(target_tensor)
|
||||
left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
|
||||
|
||||
@@ -14,6 +14,6 @@ StateDict = OrderedDict[str, Any]
|
||||
Padding = Tuple[int, int, int, int]
|
||||
FaceLandmark203 = Tensor
|
||||
VisionTensor = Tensor
|
||||
Loss = Tensor
|
||||
GeneratorLossSet = Dict[str, Loss]
|
||||
DiscriminatorLossSet = Dict[str, Loss]
|
||||
LossTensor = Tensor
|
||||
GeneratorLossSet = Dict[str, LossTensor]
|
||||
DiscriminatorLossSet = Dict[str, LossTensor]
|
||||
|
||||
Reference in New Issue
Block a user