mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce new DiscriminatorLoss class, Remove useless super call params
This commit is contained in:
@@ -4,7 +4,7 @@ from torch import Tensor, nn
|
||||
|
||||
class EmbeddingConverter(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(EmbeddingConverter, self).__init__()
|
||||
super().__init__()
|
||||
self.layers = self.create_layers()
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
|
||||
|
||||
@@ -25,12 +25,12 @@ num_discriminators =
|
||||
kernel_size =
|
||||
|
||||
[training.losses]
|
||||
weight_adversarial =
|
||||
weight_identity =
|
||||
weight_attribute =
|
||||
weight_reconstruction =
|
||||
weight_pose =
|
||||
weight_gaze =
|
||||
adversarial_weight =
|
||||
attribute_weight =
|
||||
reconstruction_weight =
|
||||
identity_weight =
|
||||
pose_weight =
|
||||
gaze_weight =
|
||||
|
||||
[training.trainer]
|
||||
learning_rate =
|
||||
|
||||
@@ -22,18 +22,6 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
|
||||
return vision_frame
|
||||
|
||||
|
||||
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
|
||||
real_loss = torch.relu(1 - input_tensor)
|
||||
real_loss = real_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return real_loss
|
||||
|
||||
|
||||
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
|
||||
fake_loss = torch.relu(input_tensor + 1)
|
||||
fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return fake_loss
|
||||
|
||||
|
||||
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
crop_tensor = input_tensor[:, :, 15: 241, 15: 241]
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
|
||||
|
||||
@@ -11,7 +11,7 @@ CONFIG.read('config.ini')
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(Discriminator, self).__init__()
|
||||
super().__init__()
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
self.discriminators = self.create_discriminators()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ CONFIG.read('config.ini')
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(Generator, self).__init__()
|
||||
super().__init__()
|
||||
encoder_type = CONFIG.get('training.model.generator', 'encoder_type')
|
||||
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
@@ -5,13 +5,25 @@ import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
|
||||
real_loss = torch.relu(1 - input_tensor)
|
||||
real_loss = real_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return real_loss
|
||||
|
||||
|
||||
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
|
||||
fake_loss = torch.relu(input_tensor + 1)
|
||||
fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return fake_loss
|
||||
|
||||
|
||||
class FaceSwapperLoss:
|
||||
def __init__(self) -> None:
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
@@ -51,6 +63,16 @@ class FaceSwapperLoss:
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
|
||||
return generator_loss_set
|
||||
|
||||
def hinge_real_loss(input_tensor: Tensor) -> Tensor:
|
||||
real_loss = torch.relu(1 - input_tensor)
|
||||
real_loss = real_loss.mean(dim = [1, 2, 3])
|
||||
return real_loss
|
||||
|
||||
def hinge_fake_loss(input_tensor: Tensor) -> Tensor:
|
||||
fake_loss = torch.relu(input_tensor + 1)
|
||||
fake_loss = fake_loss.mean(dim = [1, 2, 3])
|
||||
return fake_loss
|
||||
|
||||
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
|
||||
discriminator_loss_set = {}
|
||||
loss_fakes = []
|
||||
@@ -131,9 +153,31 @@ class FaceSwapperLoss:
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
class AdversarialLoss(torch.nn.Module):
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(AdversarialLoss, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def calc(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
|
||||
temp1_tensors = []
|
||||
temp2_tensors = []
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
temp1_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ])
|
||||
temp1_tensors.append(temp1_tensor)
|
||||
|
||||
for discriminator_source_tensor in discriminator_source_tensors:
|
||||
temp2_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ])
|
||||
temp2_tensors.append(temp2_tensor)
|
||||
|
||||
discriminator1_loss = torch.stack(temp1_tensors).mean()
|
||||
discriminator2_loss = torch.stack(temp2_tensors).mean()
|
||||
discriminator_loss = (discriminator1_loss + discriminator2_loss) * 0.5
|
||||
return discriminator_loss
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def calc(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight')
|
||||
@@ -148,9 +192,9 @@ class AdversarialLoss(torch.nn.Module):
|
||||
return adversarial_loss, weighted_adversarial_loss
|
||||
|
||||
|
||||
class AttributeLoss(torch.nn.Module):
|
||||
class AttributeLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(AttributeLoss, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
@@ -166,9 +210,9 @@ class AttributeLoss(torch.nn.Module):
|
||||
return attribute_loss, weighted_attribute_loss
|
||||
|
||||
|
||||
class ReconstructionLoss(torch.nn.Module):
|
||||
class ReconstructionLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(ReconstructionLoss, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
@@ -189,9 +233,9 @@ class ReconstructionLoss(torch.nn.Module):
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
|
||||
class IdentityLoss(torch.nn.Module):
|
||||
class IdentityLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(IdentityLoss, self).__init__()
|
||||
super().__init__()
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
|
||||
@@ -204,9 +248,9 @@ class IdentityLoss(torch.nn.Module):
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class PoseLoss(torch.nn.Module):
|
||||
class PoseLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(PoseLoss, self).__init__()
|
||||
super().__init__()
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
@@ -232,9 +276,9 @@ class PoseLoss(torch.nn.Module):
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
class GazeLoss(torch.nn.Module):
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(GazeLoss, self).__init__()
|
||||
super().__init__()
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
@@ -6,7 +6,7 @@ from ..types import Embedding, TargetAttributes
|
||||
|
||||
class AADGenerator(nn.Module):
|
||||
def __init__(self, id_channels : int, num_blocks : int) -> None:
|
||||
super(AADGenerator, self).__init__()
|
||||
super().__init__()
|
||||
self.upsample = PixelShuffleUpsample(id_channels, 1024 * 4)
|
||||
self.res_block_1 = AADResBlock(1024, 1024, 1024, id_channels, num_blocks)
|
||||
self.res_block_2 = AADResBlock(1024, 1024, 2048, id_channels, num_blocks)
|
||||
@@ -56,7 +56,7 @@ class AADLayer(nn.Module):
|
||||
|
||||
class AADSequential(nn.Module):
|
||||
def __init__(self, *args : nn.Module) -> None:
|
||||
super(AADSequential, self).__init__()
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(args)
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
|
||||
@@ -70,7 +70,7 @@ class AADSequential(nn.Module):
|
||||
|
||||
class AADResBlock(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, id_channels : int, num_blocks : int) -> None:
|
||||
super(AADResBlock, self).__init__()
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.output_channels = output_channels
|
||||
self.prepare_primary_add_blocks(input_channels, attribute_channels, id_channels, output_channels, num_blocks)
|
||||
@@ -111,7 +111,7 @@ class AADResBlock(nn.Module):
|
||||
|
||||
class PixelShuffleUpsample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(PixelShuffleUpsample, self).__init__()
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1)
|
||||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from torch import Tensor, nn
|
||||
|
||||
class NLD(nn.Module):
|
||||
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
|
||||
super(NLD, self).__init__()
|
||||
super().__init__()
|
||||
self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -7,7 +7,7 @@ from torchvision import models
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(UNet, self).__init__()
|
||||
super().__init__()
|
||||
self.down_samples = self.create_down_samples(self)
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@@ -87,7 +87,7 @@ class UNetPro(UNet):
|
||||
|
||||
class UpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(UpSample, self).__init__()
|
||||
super().__init__()
|
||||
self.conv_transpose = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
@@ -102,7 +102,7 @@ class UpSample(nn.Module):
|
||||
|
||||
class DownSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(DownSample, self).__init__()
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -31,6 +31,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss()
|
||||
self.attribute_loss = AttributeLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss()
|
||||
@@ -65,10 +66,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.manual_backward(generator_loss_set.get('loss_generator'))
|
||||
generator_optimizer.step()
|
||||
|
||||
discriminator_source_tensor = self.discriminator(source_tensor)
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
|
||||
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensor, discriminator_output_tensors)
|
||||
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_loss_set.get('loss_discriminator'))
|
||||
discriminator_optimizer.step()
|
||||
@@ -77,16 +78,17 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_loss_set.get('loss_identity'))
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
|
||||
self.log('loss_pose', generator_loss_set.get('loss_pose'))
|
||||
self.log('loss_gaze', generator_loss_set.get('loss_gaze'), prog_bar = True)
|
||||
self.log('loss_gaze', generator_loss_set.get('loss_gaze'))
|
||||
|
||||
###############################################
|
||||
|
||||
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
|
||||
@@ -96,12 +98,13 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
|
||||
|
||||
self.log('generator_loss_new', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss_new', discriminator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss_new', adversarial_loss)
|
||||
self.log('attribute_loss_new', attribute_loss)
|
||||
self.log('reconstruction_loss_new', reconstruction_loss)
|
||||
self.log('identity_loss_new', identity_loss)
|
||||
self.log('pose_loss_new', pose_loss)
|
||||
self.log('gaze_loss_new', gaze_loss, prog_bar = True)
|
||||
self.log('gaze_loss_new', gaze_loss)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
@@ -161,7 +164,7 @@ def create_trainer() -> Trainer:
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'loss_generator',
|
||||
monitor = 'generator_loss',
|
||||
dirpath = output_directory_path,
|
||||
filename = output_file_pattern,
|
||||
every_n_train_steps = 1000,
|
||||
|
||||
Reference in New Issue
Block a user