mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Next (#93)
* add sync batchnorm * replace random.choice with hash * fifty percent reduction * fix discriminator input * restore dataset.py * Remove duplicates * add discriminator_ratio to config * fix onnx export bug: replace round() with int() * Fix embedding naming * Introduce ModelWithConfigCheckpoint callback (#86) * Fix dist ini * Style: Refactor typing and improve code clarity in training.py (#88) * Add type casting for trainer params * Add type casting for trainer params * Add type casting for trainer params * Remove inplace activations for torch.compile compatibility (#89) * Fix README * improvise with norm layers & weighted average * add skip layer * use gelu instead of leaky_relu * cleanup * cleanup * Update dependencies * Different defaults and enable validation * Different defaults and enable validation * Revert to higher batch size * Just use copy over copy2 --------- Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com> Co-authored-by: NeuroDonu <112660822+NeuroDonu@users.noreply.github.com> Co-authored-by: Harisreedhar <46858047+harisreedhar@users.noreply.github.com>
This commit is contained in:
+2
-2
@@ -32,7 +32,7 @@ file_pattern = .datasets/megaface/**/*.jpg
|
||||
|
||||
```
|
||||
[training.loader]
|
||||
batch_size = 256
|
||||
batch_size = 128
|
||||
num_workers = 8
|
||||
split_ratio = 0.95
|
||||
```
|
||||
@@ -90,7 +90,7 @@ python train.py
|
||||
Launch the TensorBoard to monitor the training.
|
||||
|
||||
```
|
||||
tensorboard --logdir=.logs
|
||||
tensorboard --logdir .logs
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ def export() -> None:
|
||||
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
|
||||
os.makedirs(config_directory_path, exist_ok = True)
|
||||
model = CrossFaceTrainer.load_from_checkpoint(config_source_path, map_location ='cpu').eval()
|
||||
model = CrossFaceTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
|
||||
model.ir_version = torch.tensor(config_ir_version)
|
||||
input_tensor = torch.randn(1, 512)
|
||||
torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
|
||||
|
||||
@@ -1,28 +1,37 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class CrossFace(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layers = self.create_layers()
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.sequence = self.create_sequence()
|
||||
self.linear = nn.Linear(512, 512)
|
||||
self.apply(init_weight)
|
||||
|
||||
@staticmethod
|
||||
def create_layers() -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
def create_sequence() -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Linear(512, 1024),
|
||||
nn.LayerNorm(1024),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(1024, 2048),
|
||||
nn.LayerNorm(2048),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(2048, 1024),
|
||||
nn.LayerNorm(1024),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(1024, 512)
|
||||
])
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = input_tensor / torch.norm(input_tensor)
|
||||
temp_tensor = nn.functional.normalize(input_tensor, p = 2, dim = -1)
|
||||
return self.sequence(temp_tensor) + 0.2 * self.linear(temp_tensor)
|
||||
|
||||
for layer in self.layers[:-1]:
|
||||
output_tensor = self.leaky_relu(layer(output_tensor))
|
||||
|
||||
output_tensor = self.layers[-1](output_tensor)
|
||||
return output_tensor
|
||||
def init_weight(module : nn.Module) -> None:
	"""
	Initialize a single module in place, suitable as a callback for nn.Module.apply().

	Only nn.Linear modules are modified: the weight receives Xavier normal
	initialization and the bias, when the layer has one, is filled with 0.01.
	Every other module type is left untouched.
	"""
	if isinstance(module, nn.Linear):
		nn.init.xavier_normal_(module.weight)

		# nn.Linear(..., bias = False) keeps bias as None, which constant_ cannot fill
		if module.bias is not None:
			nn.init.constant_(module.bias, 0.01)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import os
|
||||
import shutil
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
from pathlib import Path
|
||||
from typing import Tuple, cast
|
||||
|
||||
import torch
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import Dataset, random_split
|
||||
@@ -12,7 +14,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import StaticDataset
|
||||
from .models.crossface import CrossFace
|
||||
from .types import Batch, Embedding, OptimizerSet
|
||||
from .types import Batch, Embedding, OptimizerSet, TrainerPrecision, TrainerStrategy
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
@@ -67,6 +69,13 @@ class CrossFaceTrainer(LightningModule):
|
||||
return optimizer_set
|
||||
|
||||
|
||||
class ModelWithConfigCheckpoint(ModelCheckpoint):
	"""
	ModelCheckpoint variant that stores a copy of config.ini beside every saved checkpoint.
	"""

	def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
		"""
		Delegate the save to ModelCheckpoint, then copy config.ini to the
		checkpoint path with its suffix replaced by '.ini'.
		"""
		super()._save_checkpoint(trainer, checkpoint_path)
		# config.ini is resolved relative to the current working directory
		shutil.copy('config.ini', Path(checkpoint_path).with_suffix('.ini'))
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
|
||||
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
@@ -89,8 +98,8 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
|
||||
config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
|
||||
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
@@ -105,15 +114,17 @@ def create_trainer() -> Trainer:
|
||||
precision = config_precision,
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
ModelWithConfigCheckpoint(
|
||||
monitor = 'training_loss',
|
||||
dirpath = config_directory_path,
|
||||
filename = config_file_pattern,
|
||||
every_n_epochs = 1,
|
||||
every_n_epochs = 1000,
|
||||
save_top_k = 5,
|
||||
save_last = True
|
||||
)
|
||||
]
|
||||
),
|
||||
StochasticWeightAveraging(swa_lrs = 1e-2)
|
||||
],
|
||||
val_check_interval = 1000
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
@@ -6,3 +6,6 @@ Batch : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
|
||||
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
|
||||
|
||||
+3
-2
@@ -54,7 +54,6 @@ face_masker_path = .models/face_masker.pt
|
||||
```
|
||||
[training.model.generator]
|
||||
source_channels = 512
|
||||
output_channels = 4096
|
||||
output_size = 256
|
||||
num_blocks = 2
|
||||
```
|
||||
@@ -89,10 +88,12 @@ mask_weight = 5.0
|
||||
```
|
||||
[training.trainer]
|
||||
accumulate_size = 4
|
||||
discriminator_ratio = 0.4
|
||||
gradient_clip = 20.0
|
||||
max_epochs = 50
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
sync_batchnorm = false
|
||||
preview_frequency = 100
|
||||
```
|
||||
|
||||
@@ -164,7 +165,7 @@ python train.py
|
||||
Launch the TensorBoard to monitor the training.
|
||||
|
||||
```
|
||||
tensorboard --logdir=.logs
|
||||
tensorboard --logdir .logs
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
convert_template =
|
||||
multiplier
|
||||
multiplier =
|
||||
transform_size =
|
||||
usage_mode =
|
||||
batch_mode =
|
||||
@@ -20,7 +20,6 @@ face_masker_path =
|
||||
|
||||
[training.model.generator]
|
||||
source_channels =
|
||||
output_channels =
|
||||
output_size =
|
||||
num_blocks =
|
||||
|
||||
@@ -47,10 +46,12 @@ mask_weight =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
discriminator_ratio =
|
||||
gradient_clip =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
sync_batchnorm =
|
||||
preview_frequency =
|
||||
|
||||
[training.modifier]
|
||||
|
||||
+38
-58
@@ -43,6 +43,35 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
def __len__(self) -> int:
|
||||
return len(resolve_static_file_pattern(self.config_file_pattern))
|
||||
|
||||
def prepare_equal_batch(self, source_path : str) -> Batch:
	"""
	Build a batch in which the same file serves as both source and target,
	using the dataset's own convert template on both sides.
	"""
	convert_template = self.config_convert_template
	return self.create_batch(source_path, source_path, convert_template, convert_template)
|
||||
|
||||
def prepare_same_batch(self, source_path : str) -> Batch:
	"""
	Build a batch whose target is a random entry of the directory that holds the source file.

	The pick is made over every name returned by os.listdir, so it can be the
	source file itself. Both sides use the dataset's own convert template.
	"""
	directory_path = os.path.dirname(source_path)
	file_names = os.listdir(directory_path)
	target_path = os.path.join(directory_path, random.choice(file_names))
	convert_template = self.config_convert_template
	return self.create_batch(source_path, target_path, convert_template, convert_template)
|
||||
|
||||
def prepare_source_batch(self, source_path : str) -> Batch:
	"""
	Build a batch that keeps the given source and draws a random target.

	A random section of the parser returned by filter_config_by_usage_mode('both')
	supplies the file pattern and the convert template for the target; the source
	is converted with the dataset's own convert template.
	"""
	both_config_parser = self.filter_config_by_usage_mode('both')
	section_name = random.choice(both_config_parser.sections())
	target_file_pattern = both_config_parser.get(section_name, 'file_pattern')
	target_convert_template = cast(ConvertTemplate, both_config_parser.get(section_name, 'convert_template'))
	target_paths = resolve_static_file_pattern(target_file_pattern)
	return self.create_batch(source_path, random.choice(target_paths), self.config_convert_template, target_convert_template)
|
||||
|
||||
def prepare_target_batch(self, target_path : str) -> Batch:
	"""
	Build a batch that keeps the given target and draws a random source.

	A random section of the parser returned by filter_config_by_usage_mode('both')
	supplies the file pattern and the convert template for the source; the target
	is converted with the dataset's own convert template.
	"""
	both_config_parser = self.filter_config_by_usage_mode('both')
	section_name = random.choice(both_config_parser.sections())
	source_file_pattern = both_config_parser.get(section_name, 'file_pattern')
	source_convert_template = cast(ConvertTemplate, both_config_parser.get(section_name, 'convert_template'))
	source_paths = resolve_static_file_pattern(source_file_pattern)
	return self.create_batch(random.choice(source_paths), target_path, source_convert_template, self.config_convert_template)
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
	"""
	Build a batch whose target is a random file matching the dataset's own file pattern.

	Both sides use the dataset's own convert template.
	"""
	candidate_paths = resolve_static_file_pattern(self.config_file_pattern)
	convert_template = self.config_convert_template
	return self.create_batch(source_path, random.choice(candidate_paths), convert_template, convert_template)
|
||||
|
||||
def compose_transforms(self) -> transforms:
|
||||
return transforms.Compose(
|
||||
[
|
||||
@@ -53,64 +82,6 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def prepare_equal_batch(self, source_path : str) -> Batch:
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
|
||||
return source_tensor, source_tensor
|
||||
|
||||
def prepare_same_batch(self, source_path : str) -> Batch:
|
||||
target_directory_path = os.path.dirname(source_path)
|
||||
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
|
||||
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_source_batch(self, source_path : str) -> Batch:
|
||||
config_parser = self.filter_config_by_usage_mode('both')
|
||||
config_section = random.choice(config_parser.sections())
|
||||
config_file_pattern = config_parser.get(config_section, 'file_pattern')
|
||||
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
|
||||
|
||||
target_path = random.choice(resolve_static_file_pattern(config_file_pattern))
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
target_tensor = self.conditional_convert_tensor(target_tensor, config_convert_template)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_target_batch(self, target_path : str) -> Batch:
|
||||
config_parser = self.filter_config_by_usage_mode('both')
|
||||
config_section = random.choice(config_parser.sections())
|
||||
config_file_pattern = config_parser.get(config_section, 'file_pattern')
|
||||
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
|
||||
|
||||
source_path = random.choice(resolve_static_file_pattern(config_file_pattern))
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern))
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser:
|
||||
config_parser = ConfigParser()
|
||||
|
||||
@@ -126,6 +97,15 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
|
||||
return config_parser
|
||||
|
||||
def create_batch(self, source_path : str, target_path : str, source_convert_template : ConvertTemplate, target_convert_template : ConvertTemplate) -> Batch:
	"""
	Load the source and target images and return them as a (source, target) pair.

	Each file is read with io.read_image, passed through self.transforms and then
	through conditional_convert_tensor with the convert template given for its side.
	"""
	tensors = []

	# the source is fully processed before the target, in the same order for every call
	for file_path, convert_template in ((source_path, source_convert_template), (target_path, target_convert_template)):
		image_tensor = self.transforms(io.read_image(file_path))
		tensors.append(self.conditional_convert_tensor(image_tensor, convert_template))

	source_tensor, target_tensor = tensors
	return source_tensor, target_tensor
|
||||
|
||||
@staticmethod
|
||||
def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
|
||||
if convert_template:
|
||||
|
||||
@@ -36,7 +36,7 @@ def export() -> None:
|
||||
config_precision = CONFIG_PARSER.get('exporting', 'precision')
|
||||
|
||||
os.makedirs(config_directory_path, exist_ok = True)
|
||||
model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location ='cpu').eval()
|
||||
model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
|
||||
|
||||
if config_precision == 'half':
|
||||
model = HalfPrecision(model).eval()
|
||||
|
||||
@@ -34,7 +34,7 @@ def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) ->
|
||||
return output_tensor
|
||||
|
||||
|
||||
def calculate_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
def calculate_face_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
|
||||
crop_tensor[:, :, :padding[0], :] = 0
|
||||
@@ -42,9 +42,9 @@ def calculate_embedding(embedder : EmbedderModule, input_tensor : Tensor, paddin
|
||||
crop_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
|
||||
embedding = embedder(crop_tensor)
|
||||
embedding = nn.functional.normalize(embedding, p = 2)
|
||||
return embedding
|
||||
face_embedding = embedder(crop_tensor)
|
||||
face_embedding = nn.functional.normalize(face_embedding, p = 2)
|
||||
return face_embedding
|
||||
|
||||
|
||||
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
|
||||
@@ -56,7 +56,7 @@ def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
|
||||
|
||||
|
||||
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
padding = round(input_tensor.shape[2] * factor)
|
||||
padding = int(input_tensor.shape[2] * factor + 0.5)
|
||||
kernel_size = 1 + 2 * padding
|
||||
temp_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
|
||||
output_tensor = nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
|
||||
@@ -64,7 +64,7 @@ def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
|
||||
|
||||
def erode_mask(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
padding = round(input_tensor.shape[2] * factor)
|
||||
padding = int(input_tensor.shape[2] * factor + 0.5)
|
||||
kernel_size = 1 + 2 * padding
|
||||
temp_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
|
||||
output_tensor = 1 - nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
|
||||
|
||||
@@ -3,7 +3,7 @@ import configparser
|
||||
import torch
|
||||
from torchvision import io
|
||||
|
||||
from .helper import calculate_embedding
|
||||
from .helper import calculate_face_embedding
|
||||
from .training import HyperSwapTrainer
|
||||
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
@@ -22,6 +22,6 @@ def infer() -> None:
|
||||
|
||||
source_tensor = io.read_image(config_source_path)
|
||||
target_tensor = io.read_image(config_target_path)
|
||||
source_embedding = calculate_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
source_embedding = calculate_face_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor, _ = generator(source_embedding, target_tensor)
|
||||
io.write_jpeg(output_tensor, config_output_path)
|
||||
|
||||
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calculate_embedding, dilate_mask
|
||||
from ..helper import calculate_face_embedding, dilate_mask
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
|
||||
|
||||
|
||||
@@ -14,16 +14,16 @@ class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss:
|
||||
def forward(self, discriminator_real_tensors : List[Tensor], discriminator_fake_tensors : List[Tensor]) -> Loss:
|
||||
positive_tensors = []
|
||||
negative_tensors = []
|
||||
|
||||
for discriminator_source_tensor in discriminator_source_tensors:
|
||||
positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ])
|
||||
for discriminator_real_tensor in discriminator_real_tensors:
|
||||
positive_tensor = torch.relu(1 - discriminator_real_tensor).mean(dim = [ 1, 2, 3 ])
|
||||
positive_tensors.append(positive_tensor)
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ])
|
||||
for discriminator_fake_tensor in discriminator_fake_tensors:
|
||||
negative_tensor = torch.relu(discriminator_fake_tensor + 1).mean(dim = [ 1, 2, 3 ])
|
||||
negative_tensors.append(negative_tensor)
|
||||
|
||||
positive_loss = torch.stack(positive_tensors).mean()
|
||||
@@ -97,8 +97,8 @@ class ReconstructionLoss(nn.Module):
|
||||
|
||||
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
|
||||
with torch.no_grad():
|
||||
source_embedding = calculate_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calculate_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
|
||||
source_embedding = calculate_face_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calculate_face_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
|
||||
|
||||
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
|
||||
|
||||
@@ -120,8 +120,8 @@ class IdentityLoss(nn.Module):
|
||||
self.embedder = embedder
|
||||
|
||||
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
|
||||
output_embedding = calculate_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calculate_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
output_embedding = calculate_face_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calculate_face_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_identity_loss = identity_loss * self.config_identity_weight
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
@@ -11,10 +11,9 @@ class AAD(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
|
||||
self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels)
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, 4096)
|
||||
self.layers = self.create_layers()
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
@@ -23,9 +22,9 @@ class AAD(nn.Module):
|
||||
if self.config_output_size == 128:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
AdaptiveFeatureModulation(1024, 1024, 512, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 256:
|
||||
@@ -40,21 +39,21 @@ class AAD(nn.Module):
|
||||
if self.config_output_size == 512:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 3072, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
@@ -100,7 +99,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
primary_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
|
||||
nn.ReLU(inplace = True)
|
||||
nn.ReLU()
|
||||
])
|
||||
|
||||
if index < self.context_num_blocks - 1:
|
||||
@@ -117,7 +116,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
shortcut_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
|
||||
])
|
||||
|
||||
|
||||
@@ -56,17 +56,17 @@ class BottleNeck(nn.Module):
|
||||
def __init__(self, num_filters : int):
|
||||
super().__init__()
|
||||
self.sequences = self.create_sequences(num_filters)
|
||||
self.relu = nn.ReLU(inplace = True)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
@staticmethod
|
||||
def create_sequences(num_filters : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True)
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
@@ -84,7 +84,7 @@ class UpSample(nn.Module):
|
||||
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2),
|
||||
nn.ReLU(inplace = True)
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
@@ -102,7 +102,7 @@ class DownSample(nn.Module):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2)
|
||||
)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class NLD(nn.Module):
|
||||
layers = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
nn.LeakyReLU(0.2)
|
||||
])
|
||||
|
||||
for _ in range(1, self.config_num_layers):
|
||||
@@ -30,7 +30,7 @@ class NLD(nn.Module):
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
nn.LeakyReLU(0.2)
|
||||
]
|
||||
|
||||
previous_filters = current_filters
|
||||
@@ -39,7 +39,7 @@ class NLD(nn.Module):
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding)
|
||||
]
|
||||
return layers
|
||||
|
||||
@@ -41,8 +41,8 @@ class UNet(nn.Module):
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 2048),
|
||||
DownSample(2048, 2048)
|
||||
DownSample(1024, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
@@ -50,8 +50,8 @@ class UNet(nn.Module):
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 2048),
|
||||
DownSample(2048, 4096),
|
||||
DownSample(4096, 4096)
|
||||
DownSample(2048, 2048),
|
||||
DownSample(2048, 2048)
|
||||
])
|
||||
|
||||
return down_samples
|
||||
@@ -62,36 +62,39 @@ class UNet(nn.Module):
|
||||
if self.config_output_size == 128:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(512, 512)
|
||||
UpSample(512, 512),
|
||||
UpSample(1024, 256)
|
||||
])
|
||||
|
||||
if self.config_output_size == 256:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(1024, 1024),
|
||||
UpSample(2048, 512)
|
||||
UpSample(2048, 512),
|
||||
UpSample(1024, 256)
|
||||
])
|
||||
|
||||
if self.config_output_size == 512:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(2048, 2048),
|
||||
UpSample(4096, 1024),
|
||||
UpSample(2048, 512)
|
||||
UpSample(1024, 1024),
|
||||
UpSample(2048, 512),
|
||||
UpSample(1536, 256),
|
||||
UpSample(768, 256)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(4096, 4096),
|
||||
UpSample(8192, 2048),
|
||||
UpSample(2048, 2048),
|
||||
UpSample(4096, 1024),
|
||||
UpSample(2048, 512)
|
||||
UpSample(3072, 512),
|
||||
UpSample(1536, 256),
|
||||
UpSample(768, 256)
|
||||
])
|
||||
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(1024, 256),
|
||||
UpSample(512, 128),
|
||||
UpSample(256, 64),
|
||||
UpSample(128, 32)
|
||||
@@ -130,7 +133,7 @@ class UpSample(nn.Module):
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
nn.LeakyReLU(0.1)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
|
||||
@@ -149,7 +152,7 @@ class DownSample(nn.Module):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
nn.LeakyReLU(0.1)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
|
||||
+28
-14
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from configparser import ConfigParser
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
@@ -14,11 +16,11 @@ from torch.utils.data import ConcatDataset, Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import apply_noise, calculate_embedding, erode_mask, overlay_mask
|
||||
from .helper import apply_noise, calculate_face_embedding, erode_mask, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet, TrainerPrecision, TrainerStrategy
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
|
||||
@@ -34,6 +36,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
|
||||
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
|
||||
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
|
||||
self.config_discriminator_ratio = config_parser.getfloat('training.trainer', 'discriminator_ratio')
|
||||
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
|
||||
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
|
||||
self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
|
||||
@@ -101,8 +104,8 @@ class HyperSwapTrainer(LightningModule):
|
||||
do_update = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined]
|
||||
source_embedding = calculate_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calculate_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
|
||||
source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calculate_face_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
|
||||
|
||||
if self.config_noise_factor > 0:
|
||||
source_embedding = apply_noise(source_embedding, self.config_noise_factor)
|
||||
@@ -123,9 +126,12 @@ class HyperSwapTrainer(LightningModule):
|
||||
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
|
||||
generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
if torch.randn(1).item() < self.config_discriminator_ratio:
|
||||
discriminator_real_tensors = self.discriminator(source_tensor)
|
||||
else:
|
||||
discriminator_real_tensors = self.discriminator(target_tensor)
|
||||
discriminator_fake_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_real_tensors, discriminator_fake_tensors)
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
self.manual_backward(generator_loss)
|
||||
@@ -176,9 +182,9 @@ class HyperSwapTrainer(LightningModule):
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
source_embedding = calculate_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor, _ = self.forward(source_embedding, target_tensor)
|
||||
output_embedding = calculate_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
|
||||
output_embedding = calculate_face_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
|
||||
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
||||
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
|
||||
return validation_score
|
||||
@@ -197,6 +203,13 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
|
||||
|
||||
|
||||
class ModelWithConfigCheckpoint(ModelCheckpoint):
	"""
	Checkpoint callback that stores a copy of the training config next to each checkpoint.

	For a checkpoint written to `<name>.ckpt`, the `config.ini` from the current
	working directory is copied to `<name>.ini`, so the checkpoint can later be
	restored together with the settings it was trained with.
	"""

	def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
		"""
		Save the checkpoint through the parent callback, then copy the config beside it.

		:param trainer: the running trainer, forwarded to the parent implementation
		:param checkpoint_path: destination path of the checkpoint file
		:raises FileNotFoundError: when `config.ini` is missing in the working directory
		"""
		super()._save_checkpoint(trainer, checkpoint_path)

		# NOTE(review): under distributed strategies this hook is expected to run on
		# every rank - restrict the copy to the global zero rank so that several
		# processes do not write the same destination file concurrently
		if trainer.is_global_zero:
			config_path = Path(checkpoint_path).with_suffix('.ini')
			shutil.copy('config.ini', config_path)
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
|
||||
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
@@ -239,23 +252,24 @@ def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
|
||||
config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
|
||||
config_sync_batchnorm = CONFIG_PARSER.getboolean('training.trainer', 'sync_batchnorm')
|
||||
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger(config_logger_path, config_logger_name)
|
||||
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = config_max_epochs,
|
||||
strategy = config_strategy,
|
||||
precision = config_precision,
|
||||
sync_batchnorm = config_sync_batchnorm,
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
ModelWithConfigCheckpoint(
|
||||
monitor = 'generator_loss',
|
||||
dirpath = config_directory_path,
|
||||
filename = config_file_pattern,
|
||||
|
||||
@@ -23,3 +23,6 @@ GazerModule : TypeAlias = Module
|
||||
FaceMaskerModule : TypeAlias = Module
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
|
||||
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
|
||||
|
||||
@@ -16,7 +16,6 @@ def test_aad_with_unet(output_size : int) -> None:
|
||||
'training.model.generator':
|
||||
{
|
||||
'source_channels': '512',
|
||||
'output_channels': str(output_size * 16),
|
||||
'output_size': str(output_size),
|
||||
'num_blocks': '2'
|
||||
}
|
||||
|
||||
+4
-4
@@ -1,10 +1,10 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu128
|
||||
albumentations==2.0.8
|
||||
lightning==2.5.1
|
||||
lightning==2.5.5
|
||||
onnx==1.18.0
|
||||
onnxruntime==1.22.0
|
||||
pytorch-msssim==1.0.0
|
||||
torch==2.7.1
|
||||
torch==2.8.0
|
||||
torchdata==0.11.0
|
||||
torchvision==0.22.1
|
||||
tensorboard==2.19.0
|
||||
torchvision==0.23.0
|
||||
tensorboard==2.20.0
|
||||
|
||||
Reference in New Issue
Block a user