Introduce new AttributeLoss class

This commit is contained in:
henryruhs
2025-02-22 22:31:44 +01:00
parent 7848d28b02
commit 6eabcad1d0
4 changed files with 50 additions and 28 deletions
+5 -5
View File
@@ -40,9 +40,9 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
source_embedding = self.source_embedder(batch)
target_embedding = self.target_embedder(batch)
output_embedding = self(source_embedding)
loss_training = self.mse_loss(output_embedding, target_embedding)
self.log('loss_training', loss_training, prog_bar = True)
return loss_training
training_loss = self.mse_loss(output_embedding, target_embedding)
self.log('training_loss', training_loss, prog_bar = True)
return training_loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
with torch.no_grad():
@@ -63,7 +63,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
'lr_scheduler':
{
'scheduler': scheduler,
'monitor': 'loss_training',
'monitor': 'training_loss',
'interval': 'epoch',
'frequency': 1
}
@@ -102,7 +102,7 @@ def create_trainer() -> Trainer:
callbacks =
[
ModelCheckpoint(
monitor = 'loss_training',
monitor = 'training_loss',
dirpath = output_directory_path,
filename = output_file_pattern,
every_n_epochs = 10,
+2 -2
View File
@@ -65,8 +65,8 @@ kernel_size = 4
[training.losses]
adversarial_weight = 1.5
attribute_weight = 10
reconstruction_weight = 15
identity_weight = 15
reconstruction_weight = 20
identity_weight = 20
pose_weight = 0
gaze_weight = 0
```
+33 -14
View File
@@ -3,10 +3,11 @@ from typing import List, Tuple
import torch
from pytorch_msssim import ssim
from sqlalchemy.dialects.mssql.information_schema import identity_columns
from torch import Tensor, nn
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -153,9 +154,27 @@ class AdversarialLoss(torch.nn.Module):
temp_tensor = torch.relu(1 - discriminator_output_tensor[0]).mean(dim = [ 1, 2, 3 ]).mean()
temp_tensors.append(temp_tensor)
loss = torch.stack(temp_tensors).mean()
weighted_loss = loss * adversarial_weight
return loss, weighted_loss
adversarial_loss = torch.stack(temp_tensors).mean()
weighted_adversarial_loss = adversarial_loss * adversarial_weight
return adversarial_loss, weighted_adversarial_loss
class AttributeLoss(torch.nn.Module):
def __init__(self) -> None:
super(AttributeLoss, self).__init__()
def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight')
temp_tensors = []
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
attribute_loss = torch.stack(temp_tensors).mean() * 0.5
weighted_attribute_loss = attribute_loss * attribute_weight
return attribute_loss, weighted_attribute_loss
class ReconstructionLoss(torch.nn.Module):
@@ -165,20 +184,20 @@ class ReconstructionLoss(torch.nn.Module):
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
loss = torch.mean(loss, dim = 1) * 0.5
reconstruction_loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
reconstruction_loss = torch.mean(reconstruction_loss, dim = 1) * 0.5
if torch.equal(source_tensor, target_tensor):
loss = torch.sum(loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
else:
loss = torch.sum(loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
loss = (loss + similarity) * 0.5
weighted_loss = loss * reconstruction_weight
return loss, weighted_loss
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss
class IdentityLoss(torch.nn.Module):
@@ -192,6 +211,6 @@ class IdentityLoss(torch.nn.Module):
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
weighted_loss = loss * identity_weight
return loss, weighted_loss
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
weighted_identity_loss = identity_loss * identity_weight
return identity_loss, weighted_identity_loss
+10 -7
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss
from .types import Batch, Embedding, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -32,6 +32,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.generator = Generator()
self.discriminator = Discriminator()
self.adversarial_loss = AdversarialLoss()
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss()
self.identity_loss = IdentityLoss()
self.automatic_optimization = automatic_optimization
@@ -74,23 +75,25 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True)
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
self.log('loss_attribute', generator_loss_set.get('loss_attribute'), prog_bar = True)
self.log('loss_identity', generator_loss_set.get('loss_identity'))
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
###############################################
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
generator_loss = weighted_adversarial_loss + weighted_reconstruction_loss + weighted_identity_loss
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss
self.log('generator_loss_new', generator_loss, prog_bar = True)
self.log('loss_adversarial_new', adversarial_loss, prog_bar = True)
self.log('loss_reconstruction_new', reconstruction_loss)
self.log('loss_identity_new', identity_loss)
self.log('adversarial_loss_new', adversarial_loss)
self.log('attribute_loss_new', attribute_loss, prog_bar = True)
self.log('reconstruction_loss_new', reconstruction_loss)
self.log('identity_loss_new', identity_loss)
return generator_loss_set.get('loss_generator')
def validation_step(self, batch : Batch, batch_index : int) -> Tensor: