some fixes

This commit is contained in:
harisreedhar
2025-02-15 19:04:43 +05:30
committed by henryruhs
parent 0e148845af
commit 030d912c1b
6 changed files with 28 additions and 21 deletions
+1
View File
@@ -30,6 +30,7 @@ weight_identity =
weight_attribute =
weight_reconstruction =
weight_pose =
weight_gaze =
[training.trainer]
learning_rate =
+2 -2
View File
@@ -3,8 +3,8 @@ from typing import List
from torch import nn
from face_swapper.src.networks.nld import NLD
from face_swapper.src.types import VisionTensor
from ..networks.nld import NLD
from ..types import VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
+3 -3
View File
@@ -3,9 +3,9 @@ from typing import Tuple
from torch import nn
from face_swapper.src.networks.attribute_modulator import AADGenerator
from face_swapper.src.networks.unet import UNet
from face_swapper.src.types import Embedding, TargetAttributes, VisionTensor
from ..networks.attribute_modulator import AADGenerator
from ..networks.unet import UNet
from ..types import Embedding, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
+15 -13
View File
@@ -5,8 +5,8 @@ import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -62,35 +62,37 @@ class FaceSwapperLoss:
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
discriminator_loss_set = {}
loss_fake = torch.Tensor(0)
loss_fakes = []
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
loss_fakes.append(hinge_fake_loss(fake_discriminator_output[0]))
loss_true = torch.Tensor(0)
loss_trues = []
for true_discriminator_output in real_discriminator_outputs:
loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
loss_trues.append(hinge_real_loss(true_discriminator_output[0]))
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
loss_fake = torch.stack(loss_fakes).mean()
loss_true = torch.stack(loss_trues).mean()
discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
return discriminator_loss_set
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarial = torch.Tensor(0)
loss_adversarials = []
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_real_loss(discriminator_output[0])
loss_adversarials.append(hinge_real_loss(discriminator_output[0]).mean())
loss_adversarial = torch.mean(loss_adversarial)
loss_adversarial = torch.stack(loss_adversarials).mean()
return loss_adversarial
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
loss_attribute = torch.Tensor(0)
loss_attributes = []
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
loss_attributes.append(torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean())
loss_attribute *= 0.5
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
@@ -1,7 +1,7 @@
import torch
from torch import Tensor, nn
from face_swapper.src.types import Embedding, TargetAttributes
from ..types import Embedding, TargetAttributes
class AADGenerator(nn.Module):
+6 -2
View File
@@ -25,6 +25,7 @@ CONFIG.read('config.ini')
class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
def __init__(self) -> None:
super().__init__()
FaceSwapperLoss.__init__(self)
self.generator = AdaptiveEmbeddingIntegrationNetwork()
self.discriminator = Discriminator()
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
@@ -45,12 +46,12 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
swap_attributes = self.generator.get_attributes(swap_tensor)
real_discriminator_outputs = self.discriminator(source_tensor.detach())
real_discriminator_outputs = self.discriminator(source_tensor)
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch)
generator_optimizer.zero_grad()
self.manual_backward(generator_losses.get('loss_generator'))
self.manual_backward(generator_losses.get('loss_generator'), retain_graph = True)
generator_optimizer.step()
discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs)
@@ -114,6 +115,9 @@ def train() -> None:
num_workers = CONFIG.getint('training.loader', 'num_workers')
output_file_path = CONFIG.get('training.output', 'file_path')
if not os.path.isfile(output_file_path):
output_file_path = None
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swap_model = FaceSwapperTrain()