mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Revert the config dicts
This commit is contained in:
@@ -17,8 +17,7 @@ def export() -> None:
|
||||
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
|
||||
os.makedirs(config_directory_path, exist_ok = True)
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(config_source_path, map_location = 'cpu')
|
||||
model.eval()
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(config_source_path, map_location = 'cpu').eval()
|
||||
model.ir_version = torch.tensor(config_ir_version)
|
||||
input_tensor = torch.randn(1, 512)
|
||||
torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
|
||||
|
||||
@@ -25,8 +25,8 @@ class EmbeddingConverterTrainer(LightningModule):
|
||||
self.config_target_path = config_parser.get('training.model', 'target_path')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu')
|
||||
self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu')
|
||||
self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval()
|
||||
self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_embedding : Embedding) -> Embedding:
|
||||
|
||||
+14
-20
@@ -15,27 +15,24 @@ from .types import Batch, BatchMode, WarpTemplate
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config =\
|
||||
{
|
||||
'file_pattern': config_parser.get('training.dataset', 'file_pattern'),
|
||||
'transform_size': config_parser.get('training.dataset', 'transform_size'),
|
||||
'batch_mode': cast(BatchMode, config_parser.get('training.dataset', 'batch_mode')),
|
||||
'batch_ratio': config_parser.getfloat('training.dataset', 'batch_ratio'),
|
||||
}
|
||||
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
|
||||
self.config_transform_size = config_parser.get('training.dataset', 'transform_size')
|
||||
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset', 'batch_mode'))
|
||||
self.config_batch_ratio = config_parser.getfloat('training.dataset', 'batch_ratio')
|
||||
self.config_parser = config_parser
|
||||
self.file_paths = glob.glob(self.config.get('file_pattern')) # type:ignore[type-var]
|
||||
self.file_paths = glob.glob(self.config_file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
|
||||
if random.random() < self.config.get('batch_ratio'): # type:ignore[operator]
|
||||
if self.config.get('batch_mode') == 'equal':
|
||||
return self.prepare_equal_batch(file_path) # type:ignore[arg-type]
|
||||
if self.config.get('batch_mode') == 'same':
|
||||
return self.prepare_same_batch(file_path) # type:ignore[arg-type]
|
||||
if random.random() < self.config_batch_ratio:
|
||||
if self.config_batch_mode == 'equal':
|
||||
return self.prepare_equal_batch(file_path)
|
||||
if self.config_batch_mode == 'same':
|
||||
return self.prepare_same_batch(file_path)
|
||||
|
||||
return self.prepare_different_batch(file_path) # type:ignore[arg-type]
|
||||
return self.prepare_different_batch(file_path)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
@@ -45,7 +42,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
[
|
||||
AugmentTransform(),
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((self.config.get('transform_size'), self.config.get('transform_size')), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
WarpTransform(self.config_parser),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
@@ -99,11 +96,8 @@ class AugmentTransform:
|
||||
|
||||
class WarpTransform:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config =\
|
||||
{
|
||||
'warp_template': cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
|
||||
}
|
||||
self.config_warp_template = cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.unsqueeze(0)
|
||||
return warp_tensor(temp_tensor, self.config.get('warp_template')).squeeze(0)
|
||||
return warp_tensor(temp_tensor, self.config_warp_template).squeeze(0)
|
||||
|
||||
@@ -11,20 +11,16 @@ CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def export() -> None:
|
||||
config =\
|
||||
{
|
||||
'directory_path': CONFIG_PARSER.get('exporting', 'directory_path'),
|
||||
'source_path': CONFIG_PARSER.get('exporting', 'source_path'),
|
||||
'target_path': CONFIG_PARSER.get('exporting', 'target_path'),
|
||||
'target_size': CONFIG_PARSER.getint('exporting', 'target_size'),
|
||||
'ir_version': CONFIG_PARSER.getint('exporting', 'ir_version'),
|
||||
'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
}
|
||||
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
|
||||
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
|
||||
config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
|
||||
config_target_size = CONFIG_PARSER.getint('exporting', 'target_size')
|
||||
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
|
||||
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
|
||||
os.makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
|
||||
model = FaceSwapperTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
|
||||
model.eval()
|
||||
model.ir_version = torch.tensor(config.get('ir_version'))
|
||||
os.makedirs(config_directory_path, exist_ok = True)
|
||||
model = FaceSwapperTrainer.load_from_checkpoint(config_source_path, map_location = 'cpu').eval()
|
||||
model.ir_version = torch.tensor(config_ir_version)
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, config.get('target_size'), config.get('target_size'))
|
||||
torch.onnx.export(model, (source_tensor, target_tensor), config.get('target_path'), input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = config.get('opset_version'))
|
||||
target_tensor = torch.randn(1, 3, config_target_size, config_target_size)
|
||||
torch.onnx.export(model, (source_tensor, target_tensor), config_target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = config_opset_version)
|
||||
|
||||
@@ -11,24 +11,20 @@ CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def infer() -> None:
|
||||
config =\
|
||||
{
|
||||
'generator_path': CONFIG_PARSER.get('inferencing', 'generator_path'),
|
||||
'embedder_path': CONFIG_PARSER.get('inferencing', 'embedder_path'),
|
||||
'source_path': CONFIG_PARSER.get('inferencing', 'source_path'),
|
||||
'target_path': CONFIG_PARSER.get('inferencing', 'target_path'),
|
||||
'output_path': CONFIG_PARSER.get('inferencing', 'output_path')
|
||||
}
|
||||
config_generator_path = CONFIG_PARSER.get('inferencing', 'generator_path')
|
||||
config_embedder_path = CONFIG_PARSER.get('inferencing', 'embedder_path')
|
||||
config_source_path = CONFIG_PARSER.get('inferencing', 'source_path')
|
||||
config_target_path = CONFIG_PARSER.get('inferencing', 'target_path')
|
||||
config_output_path = CONFIG_PARSER.get('inferencing', 'output_path')
|
||||
|
||||
state_dict = torch.load(config.get('generator_path')).get('state_dict').get('generator')
|
||||
state_dict = torch.load(config_generator_path).get('state_dict').get('generator')
|
||||
generator = Generator(CONFIG_PARSER)
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
embedder = torch.jit.load(config.get('embedder_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
embedder.eval()
|
||||
embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval()
|
||||
|
||||
source_tensor = io.read_image(config.get('source_path'))
|
||||
target_tensor = io.read_image(config.get('target_path'))
|
||||
source_tensor = io.read_image(config_source_path)
|
||||
target_tensor = io.read_image(config_target_path)
|
||||
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = generator(source_embedding, target_tensor)[0]
|
||||
io.write_jpeg(output_tensor, config.get('output_path'))
|
||||
io.write_jpeg(output_tensor, config_output_path)
|
||||
|
||||
@@ -35,10 +35,7 @@ class DiscriminatorLoss(nn.Module):
|
||||
class AdversarialLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'adversarial_weight': config_parser.getfloat('training.losses', 'adversarial_weight')
|
||||
}
|
||||
self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight')
|
||||
|
||||
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
temp_tensors = []
|
||||
@@ -48,38 +45,32 @@ class AdversarialLoss(nn.Module):
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
adversarial_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_adversarial_loss = adversarial_loss * self.config.get('adversarial_weight')
|
||||
weighted_adversarial_loss = adversarial_loss * self.config_adversarial_weight
|
||||
return adversarial_loss, weighted_adversarial_loss
|
||||
|
||||
|
||||
class AttributeLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'batch_size': config_parser.getint('training.loader', 'batch_size'),
|
||||
'attribute_weight': config_parser.getfloat('training.losses', 'attribute_weight')
|
||||
}
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight')
|
||||
|
||||
def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
|
||||
temp_tensors = []
|
||||
|
||||
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
|
||||
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config.get('batch_size'), -1), dim = 1).mean()
|
||||
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
attribute_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
weighted_attribute_loss = attribute_loss * self.config.get('attribute_weight')
|
||||
weighted_attribute_loss = attribute_loss * self.config_attribute_weight
|
||||
return attribute_loss, weighted_attribute_loss
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'reconstruction_weight': config_parser.getfloat('training.losses', 'reconstruction_weight')
|
||||
}
|
||||
self.config_reconstruction_weight = config_parser.getfloat('training.losses', 'reconstruction_weight')
|
||||
self.embedder = embedder
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
@@ -94,7 +85,7 @@ class ReconstructionLoss(nn.Module):
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * self.config.get('reconstruction_weight')
|
||||
weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
|
||||
@@ -156,11 +147,8 @@ class MotionLoss(nn.Module):
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'gaze_weight': config_parser.getfloat('training.losses', 'gaze_weight'),
|
||||
'output_size': config_parser.getint('training.model.generator', 'output_size')
|
||||
}
|
||||
self.config_gaze_weight = config_parser.getfloat('training.losses', 'gaze_weight')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.gazer = gazer
|
||||
self.l1_loss = nn.L1Loss()
|
||||
|
||||
@@ -172,11 +160,11 @@ class GazeLoss(nn.Module):
|
||||
yaw_loss = self.l1_loss(output_yaw, target_yaw)
|
||||
|
||||
gaze_loss = (pitch_loss + yaw_loss) * 0.5
|
||||
weighted_gaze_loss = gaze_loss * self.config.get('gaze_weight')
|
||||
weighted_gaze_loss = gaze_loss * self.config_gaze_weight
|
||||
return gaze_loss, weighted_gaze_loss
|
||||
|
||||
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
|
||||
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config.get('output_size')).int()
|
||||
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int()
|
||||
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
|
||||
crop_tensor = (crop_tensor + 1) * 0.5
|
||||
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
|
||||
|
||||
@@ -9,53 +9,50 @@ from ..types import Attributes, Embedding
|
||||
class AAD(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'identity_channels': config_parser.getint('training.model.generator', 'identity_channels'),
|
||||
'output_channels': config_parser.getint('training.model.generator', 'output_channels'),
|
||||
'output_size': config_parser.getint('training.model.generator', 'output_size'),
|
||||
'num_blocks': config_parser.getint('training.model.generator', 'num_blocks')
|
||||
}
|
||||
self.config_identity_channels = config_parser.getint('training.model.generator', 'identity_channels')
|
||||
self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
|
||||
self.config_parser = config_parser
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config.get('identity_channels'), self.config.get('output_channels'))
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_identity_channels, self.config_output_channels)
|
||||
self.layers = self.create_layers()
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
layers = nn.ModuleList()
|
||||
|
||||
if self.config.get('output_size') == 128:
|
||||
if self.config_output_size == 128:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config.get('output_size') == 256:
|
||||
if self.config_output_size == 256:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config.get('output_size') == 512:
|
||||
if self.config_output_size == 512:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(64, 3, 64, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(64, 3, 64, self.config_identity_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
return layers
|
||||
@@ -75,43 +72,40 @@ class AAD(nn.Module):
|
||||
class AdaptiveFeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None:
|
||||
super().__init__()
|
||||
self.context =\
|
||||
{
|
||||
'input_channels': input_channels,
|
||||
'output_channels': output_channels,
|
||||
'attribute_channels': attribute_channels,
|
||||
'identity_channels': identity_channels,
|
||||
'num_blocks': num_blocks
|
||||
}
|
||||
self.context_input_channels = input_channels
|
||||
self.context_output_channels = output_channels
|
||||
self.context_attribute_channels = attribute_channels
|
||||
self.context_identity_channels = identity_channels
|
||||
self.context_num_blocks = num_blocks
|
||||
self.primary_layers = self.create_primary_layers()
|
||||
self.shortcut_layers = self.create_shortcut_layers()
|
||||
|
||||
def create_primary_layers(self) -> nn.ModuleList:
|
||||
primary_layers = nn.ModuleList()
|
||||
|
||||
for index in range(self.context.get('num_blocks')):
|
||||
for index in range(self.context_num_blocks):
|
||||
primary_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context.get('input_channels'), self.context.get('attribute_channels'), self.context.get('identity_channels')),
|
||||
FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
|
||||
nn.ReLU(inplace = True)
|
||||
])
|
||||
|
||||
if index < self.context.get('num_blocks') - 1:
|
||||
primary_layers.append(nn.Conv2d(self.context.get('input_channels'), self.context.get('input_channels'), kernel_size = 3, padding = 1, bias = False))
|
||||
if index < self.context_num_blocks - 1:
|
||||
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_input_channels, kernel_size = 3, padding = 1, bias = False))
|
||||
else:
|
||||
primary_layers.append(nn.Conv2d(self.context.get('input_channels'), self.context.get('output_channels'), kernel_size = 3, padding = 1, bias = False))
|
||||
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False))
|
||||
|
||||
return primary_layers
|
||||
|
||||
def create_shortcut_layers(self) -> nn.ModuleList:
|
||||
shortcut_layers = nn.ModuleList()
|
||||
|
||||
if self.context.get('input_channels') > self.context.get('output_channels'):
|
||||
if self.context_input_channels > self.context_output_channels:
|
||||
shortcut_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context.get('input_channels'), self.context.get('attribute_channels'), self.context.get('identity_channels')),
|
||||
FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(self.context.get('input_channels'), self.context.get('output_channels'), kernel_size = 3, padding = 1, bias = False)
|
||||
nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
|
||||
])
|
||||
|
||||
return shortcut_layers
|
||||
@@ -125,7 +119,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
else:
|
||||
primary_tensor = primary_layer(primary_tensor)
|
||||
|
||||
if self.context.get('input_channels') > self.context.get('output_channels'):
|
||||
if self.context_input_channels > self.context_output_channels:
|
||||
shortcut_tensor = input_tensor
|
||||
|
||||
for shortcut_layer in self.shortcut_layers:
|
||||
@@ -142,10 +136,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
class FeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.context =\
|
||||
{
|
||||
'input_channels': input_channels
|
||||
}
|
||||
self.context_input_channels = input_channels
|
||||
self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
@@ -160,8 +151,8 @@ class FeatureModulation(nn.Module):
|
||||
attribute_shift = self.conv2(attribute_embedding)
|
||||
attribute_modulation = attribute_scale * temp_tensor + attribute_shift
|
||||
|
||||
identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.context.get('input_channels'), 1, 1).expand_as(temp_tensor)
|
||||
identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.context.get('input_channels'), 1, 1).expand_as(temp_tensor)
|
||||
identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
identity_modulation = identity_scale * temp_tensor + identity_shift
|
||||
|
||||
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
|
||||
|
||||
@@ -28,17 +28,14 @@ CONFIG_PARSER.read('config.ini')
|
||||
class FaceSwapperTrainer(LightningModule):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'embedder_path': config_parser.get('training.model', 'embedder_path'),
|
||||
'gazer_path': config_parser.get('training.model', 'gazer_path'),
|
||||
'motion_extractor_path': config_parser.get('training.model', 'motion_extractor_path'),
|
||||
'learning_rate': config_parser.getfloat('training.trainer', 'learning_rate'),
|
||||
'preview_frequency': config_parser.getint('training.trainer', 'preview_frequency')
|
||||
}
|
||||
self.embedder = torch.jit.load(self.config.get('embedder_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.gazer = torch.jit.load(self.config.get('gazer_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(self.config.get('motion_extractor_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.config_embedder_path = config_parser.get('training.model', 'embedder_path')
|
||||
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
|
||||
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
|
||||
self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval()
|
||||
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
|
||||
self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval()
|
||||
self.generator = Generator(config_parser)
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
@@ -55,8 +52,8 @@ class FaceSwapperTrainer(LightningModule):
|
||||
return output_tensor
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
||||
|
||||
@@ -113,7 +110,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
discriminator_optimizer.step()
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
if self.global_step % self.config.get('preview_frequency') == 0:
|
||||
if self.global_step % self.config_preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
@@ -149,52 +146,43 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
config =\
|
||||
{
|
||||
'batch_size': CONFIG_PARSER.getint('training.loader', 'batch_size'),
|
||||
'num_workers': CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
}
|
||||
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
|
||||
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = config.get('batch_size'), shuffle = True, num_workers = config.get('num_workers'), drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config.get('batch_size'), shuffle = False, num_workers = config.get('num_workers'), pin_memory = True, persistent_workers = True)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
|
||||
config =\
|
||||
{
|
||||
'split_ratio': CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
|
||||
}
|
||||
config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
|
||||
|
||||
dataset_size = len(dataset) # type:ignore[arg-type]
|
||||
training_size = int(dataset_size * config.get('split_ratio'))
|
||||
training_size = int(dataset_size * config_split_ratio)
|
||||
validation_size = int(dataset_size - training_size)
|
||||
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
|
||||
return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
config =\
|
||||
{
|
||||
'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'),
|
||||
'precision': CONFIG_PARSER.get('training.trainer', 'precision'),
|
||||
'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'),
|
||||
'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
}
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger('.logs', name = 'face_swapper')
|
||||
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = config.get('max_epochs'),
|
||||
precision = config.get('precision'),
|
||||
max_epochs = config_max_epochs,
|
||||
precision = config_precision,
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'generator_loss',
|
||||
dirpath = config.get('directory_path'),
|
||||
filename = config.get('file_pattern'),
|
||||
dirpath = config_directory_path,
|
||||
filename = config_file_pattern,
|
||||
every_n_train_steps = 1000,
|
||||
save_top_k = 3,
|
||||
save_last = True
|
||||
@@ -205,10 +193,7 @@ def create_trainer() -> Trainer:
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config =\
|
||||
{
|
||||
'resume_path': CONFIG_PARSER.get('training.output', 'resume_path')
|
||||
}
|
||||
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
@@ -218,7 +203,7 @@ def train() -> None:
|
||||
face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER)
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.isfile(config.get('resume_path')):
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
|
||||
if os.path.isfile(config_resume_path):
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
|
||||
else:
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader)
|
||||
|
||||
Reference in New Issue
Block a user