mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
some fixes
This commit is contained in:
@@ -30,6 +30,7 @@ weight_identity =
|
||||
weight_attribute =
|
||||
weight_reconstruction =
|
||||
weight_pose =
|
||||
weight_gaze =
|
||||
|
||||
[training.trainer]
|
||||
learning_rate =
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import List
|
||||
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.networks.nld import NLD
|
||||
from face_swapper.src.types import VisionTensor
|
||||
from ..networks.nld import NLD
|
||||
from ..types import VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
@@ -3,9 +3,9 @@ from typing import Tuple
|
||||
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.networks.attribute_modulator import AADGenerator
|
||||
from face_swapper.src.networks.unet import UNet
|
||||
from face_swapper.src.types import Embedding, TargetAttributes, VisionTensor
|
||||
from ..networks.attribute_modulator import AADGenerator
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Embedding, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
@@ -5,8 +5,8 @@ import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -62,35 +62,37 @@ class FaceSwapperLoss:
|
||||
|
||||
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
|
||||
discriminator_loss_set = {}
|
||||
loss_fake = torch.Tensor(0)
|
||||
loss_fakes = []
|
||||
|
||||
for fake_discriminator_output in fake_discriminator_outputs:
|
||||
loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
|
||||
loss_fakes.append(hinge_fake_loss(fake_discriminator_output[0]))
|
||||
|
||||
loss_true = torch.Tensor(0)
|
||||
loss_trues = []
|
||||
|
||||
for true_discriminator_output in real_discriminator_outputs:
|
||||
loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
|
||||
loss_trues.append(hinge_real_loss(true_discriminator_output[0]))
|
||||
|
||||
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
|
||||
loss_fake = torch.stack(loss_fakes).mean()
|
||||
loss_true = torch.stack(loss_trues).mean()
|
||||
discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
|
||||
return discriminator_loss_set
|
||||
|
||||
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
|
||||
loss_adversarial = torch.Tensor(0)
|
||||
loss_adversarials = []
|
||||
|
||||
for discriminator_output in discriminator_outputs:
|
||||
loss_adversarial += hinge_real_loss(discriminator_output[0])
|
||||
loss_adversarials.append(hinge_real_loss(discriminator_output[0]).mean())
|
||||
|
||||
loss_adversarial = torch.mean(loss_adversarial)
|
||||
loss_adversarial = torch.stack(loss_adversarials).mean()
|
||||
return loss_adversarial
|
||||
|
||||
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
|
||||
loss_attribute = torch.Tensor(0)
|
||||
loss_attributes = []
|
||||
|
||||
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
|
||||
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
|
||||
loss_attributes.append(torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean())
|
||||
|
||||
loss_attribute *= 0.5
|
||||
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
|
||||
return loss_attribute
|
||||
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import Embedding, TargetAttributes
|
||||
from ..types import Embedding, TargetAttributes
|
||||
|
||||
|
||||
class AADGenerator(nn.Module):
|
||||
|
||||
@@ -25,6 +25,7 @@ CONFIG.read('config.ini')
|
||||
class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
FaceSwapperLoss.__init__(self)
|
||||
self.generator = AdaptiveEmbeddingIntegrationNetwork()
|
||||
self.discriminator = Discriminator()
|
||||
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
|
||||
@@ -45,12 +46,12 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
|
||||
swap_attributes = self.generator.get_attributes(swap_tensor)
|
||||
real_discriminator_outputs = self.discriminator(source_tensor.detach())
|
||||
real_discriminator_outputs = self.discriminator(source_tensor)
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
|
||||
|
||||
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch)
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_losses.get('loss_generator'))
|
||||
self.manual_backward(generator_losses.get('loss_generator'), retain_graph = True)
|
||||
generator_optimizer.step()
|
||||
|
||||
discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs)
|
||||
@@ -114,6 +115,9 @@ def train() -> None:
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
output_file_path = CONFIG.get('training.output', 'file_path')
|
||||
|
||||
if not os.path.isfile(output_file_path):
|
||||
output_file_path = None
|
||||
|
||||
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
|
||||
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swap_model = FaceSwapperTrain()
|
||||
|
||||
Reference in New Issue
Block a user