mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce new ReconstructionLoss class
This commit is contained in:
@@ -63,12 +63,12 @@ kernel_size = 4
|
||||
|
||||
```
|
||||
[training.losses]
|
||||
weight_adversarial = 1.5
|
||||
weight_identity = 15
|
||||
weight_attribute = 10
|
||||
weight_reconstruction = 15
|
||||
weight_pose = 0
|
||||
weight_gaze = 0
|
||||
adversarial_weight = 1.5
|
||||
attribute_weight = 10
|
||||
reconstruction_weight = 15
|
||||
identity_weight = 15
|
||||
pose_weight = 0
|
||||
gaze_weight = 0
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -141,6 +141,28 @@ class FaceSwapperLoss:
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
class ReconstructionLoss(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(ReconstructionLoss, self).__init__()
|
||||
|
||||
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tensor:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
|
||||
loss_tensor = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
|
||||
loss_tensor = torch.mean(loss_tensor, dim = 1) * 0.5
|
||||
|
||||
if torch.equal(source_tensor, target_tensor):
|
||||
loss_tensor = torch.sum(loss_tensor * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
|
||||
else:
|
||||
loss_tensor = torch.sum(loss_tensor * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
|
||||
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
|
||||
loss_tensor = (loss_tensor + similarity) * 0.5
|
||||
return loss_tensor
|
||||
|
||||
|
||||
class IdentityLoss(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(IdentityLoss, self).__init__()
|
||||
@@ -148,8 +170,8 @@ class IdentityLoss(torch.nn.Module):
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.embedder.eval()
|
||||
|
||||
def calc_loss(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor:
|
||||
def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor:
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
return loss
|
||||
loss_tensor = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
return loss_tensor
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import FaceSwapperLoss, IdentityLoss
|
||||
from .models.loss import FaceSwapperLoss, IdentityLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -31,6 +31,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
self.reconstruction_loss = ReconstructionLoss()
|
||||
self.identity_loss = IdentityLoss()
|
||||
self.automatic_optimization = automatic_optimization
|
||||
|
||||
@@ -76,18 +77,24 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_loss_set.get('loss_identity'), prog_bar = True)
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'), prog_bar = True)
|
||||
|
||||
identity_loss = self.identity_loss.calc_loss(generator_output_tensor, source_tensor)
|
||||
generator_loss = self.calc_generator_loss_new(identity_loss)
|
||||
reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
|
||||
generator_loss = self.calc_generator_loss_new(reconstruction_loss, identity_loss)
|
||||
|
||||
self.log('loss_generator_new', generator_loss, prog_bar = True)
|
||||
self.log('generator_loss_new', generator_loss, prog_bar = True)
|
||||
self.log('loss_reconstruction_new', reconstruction_loss, prog_bar = True)
|
||||
self.log('loss_identity_new', identity_loss, prog_bar = True)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
|
||||
def calc_generator_loss_new(self, identity_loss : Tensor) -> Tensor:
|
||||
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
|
||||
generator_loss = identity_loss * weight_identity
|
||||
@staticmethod
|
||||
def calc_generator_loss_new(reconstruction_loss : Tensor, identity_loss : Tensor) -> Tensor:
|
||||
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
|
||||
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
|
||||
|
||||
generator_loss = reconstruction_loss * reconstruction_weight
|
||||
generator_loss += identity_loss * identity_weight
|
||||
|
||||
return generator_loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user