mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Revert the config dicts
This commit is contained in:
@@ -9,10 +9,7 @@ from ..networks.nld import NLD
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'num_discriminators': config_parser.getint('training.model.discriminator', 'num_discriminators')
|
||||
}
|
||||
self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
|
||||
self.config_parser = config_parser
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
self.discriminators = self.create_discriminators()
|
||||
@@ -20,7 +17,7 @@ class Discriminator(nn.Module):
|
||||
def create_discriminators(self) -> nn.ModuleList:
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
for _ in range(self.config.get('num_discriminators')):
|
||||
for _ in range(self.config_num_discriminators):
|
||||
discriminator = NLD(self.config_parser).sequences
|
||||
discriminators.append(discriminator)
|
||||
|
||||
|
||||
@@ -101,28 +101,22 @@ class ReconstructionLoss(nn.Module):
|
||||
class IdentityLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'identity_weight': config_parser.getfloat('training.losses', 'identity_weight')
|
||||
}
|
||||
self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
|
||||
self.embedder = embedder
|
||||
|
||||
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_identity_loss = identity_loss * self.config.get('identity_weight')
|
||||
weighted_identity_loss = identity_loss * self.config_identity_weight
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class MotionLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule):
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'pose_weight': config_parser.getfloat('training.losses', 'pose_weight'),
|
||||
'expression_weight': config_parser.getfloat('training.losses', 'expression_weight')
|
||||
}
|
||||
self.config_pose_weight = config_parser.getfloat('training.losses', 'pose_weight')
|
||||
self.expression_weight = config_parser.getfloat('training.losses', 'expression_weight')
|
||||
self.motion_extractor = motion_extractor
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
@@ -141,12 +135,12 @@ class MotionLoss(nn.Module):
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
pose_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_pose_loss = pose_loss * self.config.get('pose_weight')
|
||||
weighted_pose_loss = pose_loss * self.config_pose_weight
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
|
||||
weighted_expression_loss = expression_loss * self.config.get('expression_weight')
|
||||
weighted_expression_loss = expression_loss * self.config_expression_weight
|
||||
return expression_loss, weighted_expression_loss
|
||||
|
||||
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
|
||||
|
||||
@@ -7,31 +7,28 @@ from torch import Tensor, nn
|
||||
class NLD(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'input_channels': config_parser.getint('training.model.discriminator', 'input_channels'),
|
||||
'num_filters': config_parser.getint('training.model.discriminator', 'num_filters'),
|
||||
'kernel_size': config_parser.getint('training.model.discriminator', 'kernel_size'),
|
||||
'num_layers': config_parser.getint('training.model.discriminator', 'num_layers')
|
||||
}
|
||||
self.config_input_channels = config_parser.getint('training.model.discriminator', 'input_channels')
|
||||
self.config_num_filters = config_parser.getint('training.model.discriminator', 'num_filters')
|
||||
self.config_kernel_size = config_parser.getint('training.model.discriminator', 'kernel_size')
|
||||
self.config_num_layers = config_parser.getint('training.model.discriminator', 'num_layers')
|
||||
self.layers = self.create_layers()
|
||||
self.sequences = nn.Sequential(*self.layers)
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
padding = math.ceil((self.config.get('kernel_size') - 1) / 2)
|
||||
current_filters = self.config.get('num_filters')
|
||||
padding = math.ceil((self.config_kernel_size - 1) / 2)
|
||||
current_filters = self.config_num_filters
|
||||
layers = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(self.config.get('input_channels'), current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
|
||||
nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
])
|
||||
|
||||
for _ in range(1, self.config.get('num_layers')):
|
||||
for _ in range(1, self.config_num_layers):
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
@@ -40,10 +37,10 @@ class NLD(nn.Module):
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), padding = padding),
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(current_filters, 1, kernel_size = self.config.get('kernel_size'), padding = padding)
|
||||
nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding)
|
||||
]
|
||||
return layers
|
||||
|
||||
|
||||
@@ -8,10 +8,7 @@ from torch import Tensor, nn
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'output_size': config_parser.getint('training.model.generator', 'output_size')
|
||||
}
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.down_samples = self.create_down_samples()
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@@ -25,20 +22,20 @@ class UNet(nn.Module):
|
||||
DownSample(256, 512)
|
||||
])
|
||||
|
||||
if self.config.get('output_size') == 128:
|
||||
if self.config_output_size == 128:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 512)
|
||||
])
|
||||
|
||||
if self.config.get('output_size') == 256:
|
||||
if self.config_output_size == 256:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
if self.config.get('output_size') == 512:
|
||||
if self.config_output_size == 512:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
@@ -51,20 +48,20 @@ class UNet(nn.Module):
|
||||
def create_up_samples(self) -> nn.ModuleList:
|
||||
up_samples = nn.ModuleList()
|
||||
|
||||
if self.config.get('output_size') == 128:
|
||||
if self.config_output_size == 128:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(512, 512)
|
||||
])
|
||||
|
||||
if self.config.get('output_size') == 256:
|
||||
if self.config_output_size == 256:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(1024, 1024),
|
||||
UpSample(2048, 512)
|
||||
])
|
||||
|
||||
if self.config.get('output_size') == 512:
|
||||
if self.config_output_size == 512:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(2048, 2048),
|
||||
|
||||
Reference in New Issue
Block a user