mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce new AttributeLoss class
This commit is contained in:
@@ -40,9 +40,9 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
source_embedding = self.source_embedder(batch)
|
||||
target_embedding = self.target_embedder(batch)
|
||||
output_embedding = self(source_embedding)
|
||||
loss_training = self.mse_loss(output_embedding, target_embedding)
|
||||
self.log('loss_training', loss_training, prog_bar = True)
|
||||
return loss_training
|
||||
training_loss = self.mse_loss(output_embedding, target_embedding)
|
||||
self.log('training_loss', training_loss, prog_bar = True)
|
||||
return training_loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
with torch.no_grad():
|
||||
@@ -63,7 +63,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': scheduler,
|
||||
'monitor': 'loss_training',
|
||||
'monitor': 'training_loss',
|
||||
'interval': 'epoch',
|
||||
'frequency': 1
|
||||
}
|
||||
@@ -102,7 +102,7 @@ def create_trainer() -> Trainer:
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'loss_training',
|
||||
monitor = 'training_loss',
|
||||
dirpath = output_directory_path,
|
||||
filename = output_file_pattern,
|
||||
every_n_epochs = 10,
|
||||
|
||||
@@ -65,8 +65,8 @@ kernel_size = 4
|
||||
[training.losses]
|
||||
adversarial_weight = 1.5
|
||||
attribute_weight = 10
|
||||
reconstruction_weight = 15
|
||||
identity_weight = 15
|
||||
reconstruction_weight = 20
|
||||
identity_weight = 20
|
||||
pose_weight = 0
|
||||
gaze_weight = 0
|
||||
```
|
||||
|
||||
@@ -3,10 +3,11 @@ from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from sqlalchemy.dialects.mssql.information_schema import identity_columns
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -153,9 +154,27 @@ class AdversarialLoss(torch.nn.Module):
|
||||
temp_tensor = torch.relu(1 - discriminator_output_tensor[0]).mean(dim = [ 1, 2, 3 ]).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
loss = torch.stack(temp_tensors).mean()
|
||||
weighted_loss = loss * adversarial_weight
|
||||
return loss, weighted_loss
|
||||
adversarial_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_adversarial_loss = adversarial_loss * adversarial_weight
|
||||
return adversarial_loss, weighted_adversarial_loss
|
||||
|
||||
|
||||
class AttributeLoss(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(AttributeLoss, self).__init__()
|
||||
|
||||
def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight')
|
||||
temp_tensors = []
|
||||
|
||||
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
|
||||
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(batch_size, -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
attribute_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
weighted_attribute_loss = attribute_loss * attribute_weight
|
||||
return attribute_loss, weighted_attribute_loss
|
||||
|
||||
|
||||
class ReconstructionLoss(torch.nn.Module):
|
||||
@@ -165,20 +184,20 @@ class ReconstructionLoss(torch.nn.Module):
|
||||
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
|
||||
loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
|
||||
loss = torch.mean(loss, dim = 1) * 0.5
|
||||
reconstruction_loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
|
||||
reconstruction_loss = torch.mean(reconstruction_loss, dim = 1) * 0.5
|
||||
|
||||
if torch.equal(source_tensor, target_tensor):
|
||||
loss = torch.sum(loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
|
||||
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
|
||||
else:
|
||||
loss = torch.sum(loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
|
||||
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
|
||||
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
|
||||
loss = (loss + similarity) * 0.5
|
||||
weighted_loss = loss * reconstruction_weight
|
||||
return loss, weighted_loss
|
||||
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
|
||||
class IdentityLoss(torch.nn.Module):
|
||||
@@ -192,6 +211,6 @@ class IdentityLoss(torch.nn.Module):
|
||||
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_loss = loss * identity_weight
|
||||
return loss, weighted_loss
|
||||
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_identity_loss = identity_loss * identity_weight
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -32,6 +32,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
self.adversarial_loss = AdversarialLoss()
|
||||
self.attribute_loss = AttributeLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss()
|
||||
self.identity_loss = IdentityLoss()
|
||||
self.automatic_optimization = automatic_optimization
|
||||
@@ -74,23 +75,25 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True)
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'), prog_bar = True)
|
||||
self.log('loss_identity', generator_loss_set.get('loss_identity'))
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
|
||||
|
||||
###############################################
|
||||
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_reconstruction_loss + weighted_identity_loss
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss
|
||||
|
||||
self.log('generator_loss_new', generator_loss, prog_bar = True)
|
||||
self.log('loss_adversarial_new', adversarial_loss, prog_bar = True)
|
||||
self.log('loss_reconstruction_new', reconstruction_loss)
|
||||
self.log('loss_identity_new', identity_loss)
|
||||
self.log('adversarial_loss_new', adversarial_loss)
|
||||
self.log('attribute_loss_new', attribute_loss, prog_bar = True)
|
||||
self.log('reconstruction_loss_new', reconstruction_loss)
|
||||
self.log('identity_loss_new', identity_loss)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user