diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 01e343f..92ee043 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -30,6 +30,7 @@ weight_identity = weight_attribute = weight_reconstruction = weight_pose = +weight_gaze = [training.trainer] learning_rate = diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index ea7b814..c8fc1c4 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -3,8 +3,8 @@ from typing import List from torch import nn -from face_swapper.src.networks.nld import NLD -from face_swapper.src.types import VisionTensor +from ..networks.nld import NLD +from ..types import VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index e9b25e0..33d9c80 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -3,9 +3,9 @@ from typing import Tuple from torch import nn -from face_swapper.src.networks.attribute_modulator import AADGenerator -from face_swapper.src.networks.unet import UNet -from face_swapper.src.types import Embedding, TargetAttributes, VisionTensor +from ..networks.attribute_modulator import AADGenerator +from ..networks.unet import UNet +from ..types import Embedding, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 35d5152..02196c7 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -5,8 +5,8 @@ import torch from pytorch_msssim import ssim from torch import Tensor, nn -from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss -from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, 
TargetAttributes, VisionTensor +from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss +from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -62,35 +62,37 @@ class FaceSwapperLoss: def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet: discriminator_loss_set = {} - loss_fake = torch.Tensor(0) + loss_fakes = [] for fake_discriminator_output in fake_discriminator_outputs: - loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean() + loss_fakes.append(hinge_fake_loss(fake_discriminator_output[0]).mean()) - loss_true = torch.Tensor(0) + loss_trues = [] for true_discriminator_output in real_discriminator_outputs: - loss_true += hinge_real_loss(true_discriminator_output[0]).mean() + loss_trues.append(hinge_real_loss(true_discriminator_output[0]).mean()) - discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5 + loss_fake = torch.stack(loss_fakes).mean() + loss_true = torch.stack(loss_trues).mean() + discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5 return discriminator_loss_set def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor: - loss_adversarial = torch.Tensor(0) + loss_adversarials = [] for discriminator_output in discriminator_outputs: - loss_adversarial += hinge_real_loss(discriminator_output[0]) + loss_adversarials.append(hinge_real_loss(discriminator_output[0]).mean()) - loss_adversarial = torch.mean(loss_adversarial) + loss_adversarial = torch.stack(loss_adversarials).mean() return loss_adversarial def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor: - loss_attribute = torch.Tensor(0) + loss_attributes = [] 
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes): - loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean() + loss_attributes.append(torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()) - loss_attribute *= 0.5 + loss_attribute = torch.stack(loss_attributes).mean() * 0.5 return loss_attribute def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: diff --git a/face_swapper/src/networks/attribute_modulator.py b/face_swapper/src/networks/attribute_modulator.py index 1dae96c..4d3e973 100644 --- a/face_swapper/src/networks/attribute_modulator.py +++ b/face_swapper/src/networks/attribute_modulator.py @@ -1,7 +1,7 @@ import torch from torch import Tensor, nn -from face_swapper.src.types import Embedding, TargetAttributes +from ..types import Embedding, TargetAttributes class AADGenerator(nn.Module): diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index a68dc7b..629048b 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -25,6 +25,7 @@ CONFIG.read('config.ini') class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss): def __init__(self) -> None: super().__init__() + FaceSwapperLoss.__init__(self) self.generator = AdaptiveEmbeddingIntegrationNetwork() self.discriminator = Discriminator() self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') @@ -45,12 +46,12 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss): source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) swap_attributes = self.generator.get_attributes(swap_tensor) - real_discriminator_outputs = self.discriminator(source_tensor.detach()) + 
real_discriminator_outputs = self.discriminator(source_tensor) fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch) generator_optimizer.zero_grad() - self.manual_backward(generator_losses.get('loss_generator')) + self.manual_backward(generator_losses.get('loss_generator'), retain_graph = True) generator_optimizer.step() discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs) @@ -114,6 +115,9 @@ def train() -> None: num_workers = CONFIG.getint('training.loader', 'num_workers') output_file_path = CONFIG.get('training.output', 'file_path') + if not os.path.isfile(output_file_path): + output_file_path = None + dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability) data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) face_swap_model = FaceSwapperTrain()