Introduce new ReconstructionLoss class

This commit is contained in:
henryruhs
2025-02-22 14:29:16 +01:00
parent 085c493e18
commit 086d9eed87
3 changed files with 46 additions and 17 deletions
+6 -6
View File
@@ -63,12 +63,12 @@ kernel_size = 4
```
[training.losses]
weight_adversarial = 1.5
weight_identity = 15
weight_attribute = 10
weight_reconstruction = 15
weight_pose = 0
weight_gaze = 0
adversarial_weight = 1.5
attribute_weight = 10
reconstruction_weight = 15
identity_weight = 15
pose_weight = 0
gaze_weight = 0
```
```
+25 -3
View File
@@ -141,6 +141,28 @@ class FaceSwapperLoss:
return translation, scale, rotation
class ReconstructionLoss(torch.nn.Module):
def __init__(self) -> None:
super(ReconstructionLoss, self).__init__()
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tensor:
batch_size = CONFIG.getint('training.loader', 'batch_size')
loss_tensor = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
loss_tensor = torch.mean(loss_tensor, dim = 1) * 0.5
if torch.equal(source_tensor, target_tensor):
loss_tensor = torch.sum(loss_tensor * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
else:
loss_tensor = torch.sum(loss_tensor * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
loss_tensor = (loss_tensor + similarity) * 0.5
return loss_tensor
class IdentityLoss(torch.nn.Module):
def __init__(self) -> None:
super(IdentityLoss, self).__init__()
@@ -148,8 +170,8 @@ class IdentityLoss(torch.nn.Module):
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.embedder.eval()
def calc_loss(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor:
def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor:
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
return loss
loss_tensor = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
return loss_tensor
+15 -8
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import FaceSwapperLoss, IdentityLoss
from .models.loss import FaceSwapperLoss, IdentityLoss, ReconstructionLoss
from .types import Batch, Embedding, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -31,6 +31,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.generator = Generator()
self.discriminator = Discriminator()
self.reconstruction_loss = ReconstructionLoss()
self.identity_loss = IdentityLoss()
self.automatic_optimization = automatic_optimization
@@ -76,18 +77,24 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
self.log('loss_identity', generator_loss_set.get('loss_identity'), prog_bar = True)
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'), prog_bar = True)
identity_loss = self.identity_loss.calc_loss(generator_output_tensor, source_tensor)
generator_loss = self.calc_generator_loss_new(identity_loss)
reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
generator_loss = self.calc_generator_loss_new(reconstruction_loss, identity_loss)
self.log('loss_generator_new', generator_loss, prog_bar = True)
self.log('generator_loss_new', generator_loss, prog_bar = True)
self.log('loss_reconstruction_new', reconstruction_loss, prog_bar = True)
self.log('loss_identity_new', identity_loss, prog_bar = True)
return generator_loss_set.get('loss_generator')
def calc_generator_loss_new(self, identity_loss : Tensor) -> Tensor:
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
generator_loss = identity_loss * weight_identity
@staticmethod
def calc_generator_loss_new(reconstruction_loss : Tensor, identity_loss : Tensor) -> Tensor:
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
generator_loss = reconstruction_loss * reconstruction_weight
generator_loss += identity_loss * identity_weight
return generator_loss