diff --git a/crossface/README.md b/crossface/README.md index 07e68b3..efc679e 100644 --- a/crossface/README.md +++ b/crossface/README.md @@ -32,7 +32,7 @@ file_pattern = .datasets/megaface/**/*.jpg ``` [training.loader] -batch_size = 256 +batch_size = 128 num_workers = 8 split_ratio = 0.95 ``` @@ -90,7 +90,7 @@ python train.py Launch the TensorBoard to monitor the training. ``` -tensorboard --logdir=.logs +tensorboard --logdir .logs ``` diff --git a/crossface/src/exporting.py b/crossface/src/exporting.py index 7b85fe3..a1cdc3d 100644 --- a/crossface/src/exporting.py +++ b/crossface/src/exporting.py @@ -17,7 +17,7 @@ def export() -> None: config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version') os.makedirs(config_directory_path, exist_ok = True) - model = CrossFaceTrainer.load_from_checkpoint(config_source_path, map_location ='cpu').eval() + model = CrossFaceTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() model.ir_version = torch.tensor(config_ir_version) input_tensor = torch.randn(1, 512) torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version) diff --git a/crossface/src/models/crossface.py b/crossface/src/models/crossface.py index 1edf564..c0bc43a 100644 --- a/crossface/src/models/crossface.py +++ b/crossface/src/models/crossface.py @@ -1,28 +1,37 @@ -import torch from torch import Tensor, nn class CrossFace(nn.Module): def __init__(self) -> None: super().__init__() - self.layers = self.create_layers() - self.leaky_relu = nn.LeakyReLU() + self.sequence = self.create_sequence() + self.linear = nn.Linear(512, 512) + self.apply(init_weight) @staticmethod - def create_layers() -> nn.ModuleList: - return nn.ModuleList( - [ + def create_sequence() -> nn.Sequential: + return nn.Sequential( nn.Linear(512, 1024), + nn.LayerNorm(1024), + nn.GELU(), + nn.Dropout(0.1), nn.Linear(1024, 2048), + nn.LayerNorm(2048), + nn.GELU(), + nn.Dropout(0.1), nn.Linear(2048, 1024), + nn.LayerNorm(1024), + nn.GELU(), + nn.Dropout(0.1), nn.Linear(1024, 512) - ]) + ) def forward(self, input_tensor : Tensor) -> Tensor: - output_tensor = input_tensor / torch.norm(input_tensor) + temp_tensor = nn.functional.normalize(input_tensor, p = 2, dim = -1) + return self.sequence(temp_tensor) + 0.2 * self.linear(temp_tensor) - for layer in self.layers[:-1]: - output_tensor = self.leaky_relu(layer(output_tensor)) - output_tensor = self.layers[-1](output_tensor) - return output_tensor +def init_weight(module : nn.Module) -> None: + if isinstance(module, nn.Linear): + nn.init.xavier_normal_(module.weight) + nn.init.constant_(module.bias, 0.01) diff --git a/crossface/src/training.py b/crossface/src/training.py index 6df45ca..7397090 100644 --- a/crossface/src/training.py +++ b/crossface/src/training.py @@ -1,10 +1,12 @@ import os +import shutil from configparser import ConfigParser -from typing import Tuple +from pathlib import Path +from typing import Tuple, cast import torch from lightning import LightningModule, Trainer -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn from torch.utils.data import Dataset, random_split @@ -12,7 +14,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import StaticDataset from .models.crossface import CrossFace -from .types import Batch, Embedding, OptimizerSet +from .types import Batch, Embedding, OptimizerSet, TrainerPrecision, TrainerStrategy CONFIG_PARSER = ConfigParser() CONFIG_PARSER.read('config.ini') @@ -67,6 +69,13 @@ class CrossFaceTrainer(LightningModule): return optimizer_set +class ModelWithConfigCheckpoint(ModelCheckpoint): + def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None: + super()._save_checkpoint(trainer, checkpoint_path) + config_path = Path(checkpoint_path).with_suffix('.ini') + shutil.copy('config.ini', config_path) + + def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size') config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers') @@ -89,8 +98,8 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') - config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') - config_precision = CONFIG_PARSER.get('training.trainer', 'precision') + config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy')) + config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision')) config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') @@ -105,15 +114,17 @@ def create_trainer() -> Trainer: precision = config_precision, callbacks = [ - ModelCheckpoint( + ModelWithConfigCheckpoint( monitor = 'training_loss', dirpath = config_directory_path, filename = config_file_pattern, - every_n_epochs = 1, + every_n_epochs = 1000, save_top_k = 5, save_last = True - ) - ] + ), + StochasticWeightAveraging(swa_lrs = 1e-2) + ], + val_check_interval = 1000 ) diff --git a/crossface/src/types.py b/crossface/src/types.py index 0522b39..5a6699b 100644 --- a/crossface/src/types.py +++ b/crossface/src/types.py @@ -1,4 +1,4 @@ -from typing import Any, TypeAlias +from typing import Any, Literal, TypeAlias from torch import Tensor @@ -6,3 +6,6 @@ Batch : TypeAlias = Tensor Embedding : TypeAlias = Tensor OptimizerSet : TypeAlias = Any + +TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true'] +TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16'] diff --git a/hyperswap/README.md b/hyperswap/README.md index a25409a..1f6b7b3 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -54,7 +54,6 @@ face_masker_path = .models/face_masker.pt ``` [training.model.generator] source_channels = 512 -output_channels = 4096 output_size = 256 num_blocks = 2 ``` @@ -89,10 +88,12 @@ mask_weight = 5.0 ``` [training.trainer] accumulate_size = 4 +discriminator_ratio = 0.4 gradient_clip = 20.0 max_epochs = 50 strategy = auto precision = 16-mixed +sync_batchnorm = false preview_frequency = 100 ``` @@ -164,7 +165,7 @@ python train.py Launch the TensorBoard to monitor the training. ``` -tensorboard --logdir=.logs +tensorboard --logdir .logs ``` diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 6935b29..f4bea66 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -1,7 +1,7 @@ [training.dataset] file_pattern = convert_template = -multiplier +multiplier = transform_size = usage_mode = batch_mode = @@ -20,7 +20,6 @@ face_masker_path = [training.model.generator] source_channels = -output_channels = output_size = num_blocks = @@ -47,10 +46,12 @@ mask_weight = [training.trainer] accumulate_size = +discriminator_ratio = gradient_clip = max_epochs = strategy = precision = +sync_batchnorm = preview_frequency = [training.modifier] diff --git a/hyperswap/src/dataset.py b/hyperswap/src/dataset.py index 3243717..385bab7 100644 --- a/hyperswap/src/dataset.py +++ b/hyperswap/src/dataset.py @@ -43,6 +43,35 @@ class DynamicDataset(Dataset[Tensor]): def __len__(self) -> int: return len(resolve_static_file_pattern(self.config_file_pattern)) + def prepare_equal_batch(self, source_path : str) -> Batch: + return self.create_batch(source_path, source_path, self.config_convert_template, self.config_convert_template) + + def prepare_same_batch(self, source_path : str) -> Batch: + target_directory_path = os.path.dirname(source_path) + target_file_name_and_extension = random.choice(os.listdir(target_directory_path)) + target_path = os.path.join(target_directory_path, target_file_name_and_extension) + return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template) + + def prepare_source_batch(self, source_path : str) -> Batch: + config_parser = self.filter_config_by_usage_mode('both') + config_section = random.choice(config_parser.sections()) + config_file_pattern = config_parser.get(config_section, 'file_pattern') + config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) + target_path = random.choice(resolve_static_file_pattern(config_file_pattern)) + return self.create_batch(source_path, target_path, self.config_convert_template, config_convert_template) + + def prepare_target_batch(self, target_path : str) -> Batch: + config_parser = self.filter_config_by_usage_mode('both') + config_section = random.choice(config_parser.sections()) + config_file_pattern = config_parser.get(config_section, 'file_pattern') + config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) + source_path = random.choice(resolve_static_file_pattern(config_file_pattern)) + return self.create_batch(source_path, target_path, config_convert_template, self.config_convert_template) + + def prepare_different_batch(self, source_path : str) -> Batch: + target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern)) + return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template) + def compose_transforms(self) -> transforms: return transforms.Compose( [ @@ -53,64 +82,6 @@ class DynamicDataset(Dataset[Tensor]): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) - def prepare_equal_batch(self, source_path : str) -> Batch: - source_tensor = io.read_image(source_path) - source_tensor = self.transforms(source_tensor) - source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) - return source_tensor, source_tensor - - def prepare_same_batch(self, source_path : str) -> Batch: - target_directory_path = os.path.dirname(source_path) - target_file_name_and_extension = random.choice(os.listdir(target_directory_path)) - target_path = os.path.join(target_directory_path, target_file_name_and_extension) - source_tensor = io.read_image(source_path) - source_tensor = self.transforms(source_tensor) - source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) - target_tensor = io.read_image(target_path) - target_tensor = self.transforms(target_tensor) - target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template) - return source_tensor, target_tensor - - def prepare_source_batch(self, source_path : str) -> Batch: - config_parser = self.filter_config_by_usage_mode('both') - config_section = random.choice(config_parser.sections()) - config_file_pattern = config_parser.get(config_section, 'file_pattern') - config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) - - target_path = random.choice(resolve_static_file_pattern(config_file_pattern)) - source_tensor = io.read_image(source_path) - source_tensor = self.transforms(source_tensor) - source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) - target_tensor = io.read_image(target_path) - target_tensor = self.transforms(target_tensor) - target_tensor = self.conditional_convert_tensor(target_tensor, config_convert_template) - return source_tensor, target_tensor - - def prepare_target_batch(self, target_path : str) -> Batch: - config_parser = self.filter_config_by_usage_mode('both') - config_section = random.choice(config_parser.sections()) - config_file_pattern = config_parser.get(config_section, 'file_pattern') - config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) - - source_path = random.choice(resolve_static_file_pattern(config_file_pattern)) - source_tensor = io.read_image(source_path) - source_tensor = self.transforms(source_tensor) - source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template) - target_tensor = io.read_image(target_path) - target_tensor = self.transforms(target_tensor) - target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template) - return source_tensor, target_tensor - - def prepare_different_batch(self, source_path : str) -> Batch: - target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern)) - source_tensor = io.read_image(source_path) - source_tensor = self.transforms(source_tensor) - source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) - target_tensor = io.read_image(target_path) - target_tensor = self.transforms(target_tensor) - target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template) - return source_tensor, target_tensor - def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser: config_parser = ConfigParser() @@ -126,6 +97,15 @@ class DynamicDataset(Dataset[Tensor]): return config_parser + def create_batch(self, source_path : str, target_path : str, source_convert_template : ConvertTemplate, target_convert_template : ConvertTemplate) -> Batch: + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + source_tensor = self.conditional_convert_tensor(source_tensor, source_convert_template) + target_tensor = io.read_image(target_path) + target_tensor = self.transforms(target_tensor) + target_tensor = self.conditional_convert_tensor(target_tensor, target_convert_template) + return source_tensor, target_tensor + @staticmethod def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor: if convert_template: diff --git a/hyperswap/src/exporting.py b/hyperswap/src/exporting.py index 20c5020..facdfa8 100644 --- a/hyperswap/src/exporting.py +++ b/hyperswap/src/exporting.py @@ -36,7 +36,7 @@ def export() -> None: config_precision = CONFIG_PARSER.get('exporting', 'precision') os.makedirs(config_directory_path, exist_ok = True) - model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location ='cpu').eval() + model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() if config_precision == 'half': model = HalfPrecision(model).eval() diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py index 2abc70a..b498da4 100644 --- a/hyperswap/src/helper.py +++ b/hyperswap/src/helper.py @@ -34,7 +34,7 @@ def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> return output_tensor -def calculate_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding: +def calculate_face_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding: crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2') crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area') crop_tensor[:, :, :padding[0], :] = 0 @@ -42,9 +42,9 @@ def calculate_embedding(embedder : EmbedderModule, input_tensor : Tensor, paddin crop_tensor[:, :, :, :padding[2]] = 0 crop_tensor[:, :, :, 112 - padding[3]:] = 0 - embedding = embedder(crop_tensor) - embedding = nn.functional.normalize(embedding, p = 2) - return embedding + face_embedding = embedder(crop_tensor) + face_embedding = nn.functional.normalize(face_embedding, p = 2) + return face_embedding def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor: @@ -56,7 +56,7 @@ def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor: def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor: - padding = round(input_tensor.shape[2] * factor) + padding = int(input_tensor.shape[2] * factor + 0.5) kernel_size = 1 + 2 * padding temp_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate') output_tensor = nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0) @@ -64,7 +64,7 @@ def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor: def erode_mask(input_tensor : Tensor, factor : float) -> Tensor: - padding = round(input_tensor.shape[2] * factor) + padding = int(input_tensor.shape[2] * factor + 0.5) kernel_size = 1 + 2 * padding temp_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate') output_tensor = 1 - nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0) diff --git a/hyperswap/src/inferencing.py b/hyperswap/src/inferencing.py index 3271ccf..09c1f72 100644 --- a/hyperswap/src/inferencing.py +++ b/hyperswap/src/inferencing.py @@ -3,7 +3,7 @@ import configparser import torch from torchvision import io -from .helper import calculate_embedding +from .helper import calculate_face_embedding from .training import HyperSwapTrainer CONFIG_PARSER = configparser.ConfigParser() @@ -22,6 +22,6 @@ def infer() -> None: source_tensor = io.read_image(config_source_path) target_tensor = io.read_image(config_target_path) - source_embedding = calculate_embedding(embedder, source_tensor, (0, 0, 0, 0)) + source_embedding = calculate_face_embedding(embedder, source_tensor, (0, 0, 0, 0)) output_tensor, _ = generator(source_embedding, target_tensor) io.write_jpeg(output_tensor, config_output_path) diff --git a/hyperswap/src/models/loss.py b/hyperswap/src/models/loss.py index 45a32a2..06552de 100644 --- a/hyperswap/src/models/loss.py +++ b/hyperswap/src/models/loss.py @@ -6,7 +6,7 @@ from pytorch_msssim import ssim from torch import Tensor, nn from torchvision import transforms -from ..helper import calculate_embedding, dilate_mask +from ..helper import calculate_face_embedding, dilate_mask from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask @@ -14,16 +14,16 @@ class DiscriminatorLoss(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss: + def forward(self, discriminator_real_tensors : List[Tensor], discriminator_fake_tensors : List[Tensor]) -> Loss: positive_tensors = [] negative_tensors = [] - for discriminator_source_tensor in discriminator_source_tensors: - positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ]) + for discriminator_real_tensor in discriminator_real_tensors: + positive_tensor = torch.relu(1 - discriminator_real_tensor).mean(dim = [ 1, 2, 3 ]) positive_tensors.append(positive_tensor) - for discriminator_output_tensor in discriminator_output_tensors: - negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ]) + for discriminator_fake_tensor in discriminator_fake_tensors: + negative_tensor = torch.relu(discriminator_fake_tensor + 1).mean(dim = [ 1, 2, 3 ]) negative_tensors.append(negative_tensor) positive_loss = torch.stack(positive_tensors).mean() @@ -97,8 +97,8 @@ class ReconstructionLoss(nn.Module): def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: with torch.no_grad(): - source_embedding = calculate_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - target_embedding = calculate_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) + source_embedding = calculate_face_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calculate_face_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 @@ -120,8 +120,8 @@ class IdentityLoss(nn.Module): self.embedder = embedder def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: - output_embedding = calculate_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) - source_embedding = calculate_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) + output_embedding = calculate_face_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) + source_embedding = calculate_face_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() weighted_identity_loss = identity_loss * self.config_identity_weight return identity_loss, weighted_identity_loss diff --git a/hyperswap/src/networks/aad.py b/hyperswap/src/networks/aad.py index 10c1231..63fbb40 100644 --- a/hyperswap/src/networks/aad.py +++ b/hyperswap/src/networks/aad.py @@ -11,10 +11,9 @@ class AAD(nn.Module): def __init__(self, config_parser : ConfigParser) -> None: super().__init__() self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels') - self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels') self.config_output_size = config_parser.getint('training.model.generator', 'output_size') self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks') - self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels) + self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, 4096) self.layers = self.create_layers() def create_layers(self) -> nn.ModuleList: @@ -23,9 +22,9 @@ class AAD(nn.Module): if self.config_output_size == 128: layers.extend( [ - AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks) + AdaptiveFeatureModulation(1024, 1024, 512, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) ]) if self.config_output_size == 256: @@ -40,21 +39,21 @@ class AAD(nn.Module): if self.config_output_size == 512: layers.extend( [ - AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks), AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) ]) if self.config_output_size == 1024: layers.extend( [ - AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks), - AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 4096, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 3072, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks), AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) ]) @@ -100,7 +99,7 @@ class AdaptiveFeatureModulation(nn.Module): primary_layers.extend( [ FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels), - nn.ReLU(inplace = True) + nn.ReLU() ]) if index < self.context_num_blocks - 1: @@ -117,7 +116,7 @@ class AdaptiveFeatureModulation(nn.Module): shortcut_layers.extend( [ FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels), - nn.ReLU(inplace = True), + nn.ReLU(), nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False) ]) diff --git a/hyperswap/src/networks/masknet.py b/hyperswap/src/networks/masknet.py index 023b767..166b1c6 100644 --- a/hyperswap/src/networks/masknet.py +++ b/hyperswap/src/networks/masknet.py @@ -56,17 +56,17 @@ class BottleNeck(nn.Module): def __init__(self, num_filters : int): super().__init__() self.sequences = self.create_sequences(num_filters) - self.relu = nn.ReLU(inplace = True) + self.relu = nn.ReLU() @staticmethod def create_sequences(num_filters : int) -> nn.Sequential: return nn.Sequential( nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False), nn.BatchNorm2d(num_filters), - nn.ReLU(inplace = True), + nn.ReLU(), nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False), nn.BatchNorm2d(num_filters), - nn.ReLU(inplace = True) + nn.ReLU() ) def forward(self, input_tensor : Tensor) -> Tensor: @@ -84,7 +84,7 @@ class UpSample(nn.Module): def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: return nn.Sequential( nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2), - nn.ReLU(inplace = True) + nn.ReLU() ) def forward(self, input_tensor : Tensor) -> Tensor: @@ -102,7 +102,7 @@ class DownSample(nn.Module): return nn.Sequential( nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False), nn.BatchNorm2d(output_channels), - nn.ReLU(inplace = True), + nn.ReLU(), nn.MaxPool2d(2) ) diff --git a/hyperswap/src/networks/nld.py b/hyperswap/src/networks/nld.py index cf2ffd9..e306d54 100644 --- a/hyperswap/src/networks/nld.py +++ b/hyperswap/src/networks/nld.py @@ -20,7 +20,7 @@ class NLD(nn.Module): layers = nn.ModuleList( [ nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding), - nn.LeakyReLU(0.2, True) + nn.LeakyReLU(0.2) ]) for _ in range(1, self.config_num_layers): @@ -30,7 +30,7 @@ class NLD(nn.Module): [ nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding), nn.InstanceNorm2d(current_filters), - nn.LeakyReLU(0.2, True) + nn.LeakyReLU(0.2) ] previous_filters = current_filters @@ -39,7 +39,7 @@ class NLD(nn.Module): [ nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding), nn.InstanceNorm2d(current_filters), - nn.LeakyReLU(0.2, True), + nn.LeakyReLU(0.2), nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding) ] return layers diff --git a/hyperswap/src/networks/unet.py b/hyperswap/src/networks/unet.py index 5d63561..91c08ca 100644 --- a/hyperswap/src/networks/unet.py +++ b/hyperswap/src/networks/unet.py @@ -41,8 +41,8 @@ class UNet(nn.Module): down_samples.extend( [ DownSample(512, 1024), - DownSample(1024, 2048), - DownSample(2048, 2048) + DownSample(1024, 1024), + DownSample(1024, 1024) ]) if self.config_output_size == 1024: @@ -50,8 +50,8 @@ class UNet(nn.Module): [ DownSample(512, 1024), DownSample(1024, 2048), - DownSample(2048, 4096), - DownSample(4096, 4096) + DownSample(2048, 2048), + DownSample(2048, 2048) ]) return down_samples @@ -62,36 +62,39 @@ class UNet(nn.Module): if self.config_output_size == 128: up_samples.extend( [ - UpSample(512, 512) + UpSample(512, 512), + UpSample(1024, 256) ]) if self.config_output_size == 256: up_samples.extend( [ UpSample(1024, 1024), - UpSample(2048, 512) + UpSample(2048, 512), + UpSample(1024, 256) ]) if self.config_output_size == 512: up_samples.extend( [ - UpSample(2048, 2048), - UpSample(4096, 1024), - UpSample(2048, 512) + UpSample(1024, 1024), + UpSample(2048, 512), + UpSample(1536, 256), + UpSample(768, 256) ]) if self.config_output_size == 1024: up_samples.extend( [ - UpSample(4096, 4096), - UpSample(8192, 2048), + UpSample(2048, 2048), UpSample(4096, 1024), - UpSample(2048, 512) + UpSample(3072, 512), + UpSample(1536, 256), + UpSample(768, 256) ]) up_samples.extend( [ - UpSample(1024, 256), UpSample(512, 128), UpSample(256, 64), UpSample(128, 32) @@ -130,7 +133,7 @@ class UpSample(nn.Module): return nn.Sequential( nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False), nn.BatchNorm2d(output_channels), - nn.LeakyReLU(0.1, inplace = True) + nn.LeakyReLU(0.1) ) def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor: @@ -149,7 +152,7 @@ class DownSample(nn.Module): return nn.Sequential( nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False), nn.BatchNorm2d(output_channels), - nn.LeakyReLU(0.1, inplace = True) + nn.LeakyReLU(0.1) ) def forward(self, input_tensor : Tensor) -> Tensor: diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index f647ab9..7b89366 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -1,8 +1,10 @@ import os +import shutil import warnings from configparser import ConfigParser from copy import deepcopy -from typing import List, Tuple +from pathlib import Path +from typing import List, Tuple, cast import torch import torchvision @@ -14,11 +16,11 @@ from torch.utils.data import ConcatDataset, Dataset, random_split from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import DynamicDataset -from .helper import apply_noise, calculate_embedding, erode_mask, overlay_mask +from .helper import apply_noise, calculate_face_embedding, erode_mask, overlay_mask from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss -from .types import Batch, Embedding, Mask, OptimizerSet +from .types import Batch, Embedding, Mask, OptimizerSet, TrainerPrecision, TrainerStrategy warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') @@ -34,6 +36,7 @@ class HyperSwapTrainer(LightningModule): self.config_gazer_path = config_parser.get('training.model', 'gazer_path') self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') + self.config_discriminator_ratio = config_parser.getfloat('training.trainer', 'discriminator_ratio') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor') @@ -101,8 +104,8 @@ class HyperSwapTrainer(LightningModule): do_update = (batch_index + 1) % self.config_accumulate_size == 0 generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined] - source_embedding = calculate_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) - target_embedding = calculate_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0)) + source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calculate_face_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0)) if self.config_noise_factor > 0: source_embedding = apply_noise(source_embedding, self.config_noise_factor) @@ -123,9 +126,12 @@ class HyperSwapTrainer(LightningModule): mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask) generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss - discriminator_source_tensors = self.discriminator(source_tensor) - discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) - discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) + if torch.randn(1).item() < self.config_discriminator_ratio: + discriminator_real_tensors = self.discriminator(source_tensor) + else: + discriminator_real_tensors = self.discriminator(target_tensor) + discriminator_fake_tensors = self.discriminator(generator_output_tensor.detach()) + discriminator_loss = self.discriminator_loss(discriminator_real_tensors, discriminator_fake_tensors) self.toggle_optimizer(generator_optimizer) self.manual_backward(generator_loss) @@ -176,9 +182,9 @@ class HyperSwapTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch - source_embedding = calculate_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) + source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) output_tensor, _ = self.forward(source_embedding, target_tensor) - output_embedding = calculate_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0)) + output_embedding = calculate_face_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, sync_dist = True, prog_bar = True) return validation_score @@ -197,6 +203,13 @@ class HyperSwapTrainer(LightningModule): self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined] +class ModelWithConfigCheckpoint(ModelCheckpoint): + def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None: + super()._save_checkpoint(trainer, checkpoint_path) + config_path = Path(checkpoint_path).with_suffix('.ini') + shutil.copy('config.ini', config_path) + + def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size') config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers') @@ -239,23 +252,24 @@ def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]: def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') - config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') - config_precision = CONFIG_PARSER.get('training.trainer', 'precision') + config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy')) + config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision')) + config_sync_batchnorm = CONFIG_PARSER.getboolean('training.trainer', 'sync_batchnorm') config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger(config_logger_path, config_logger_name) - return Trainer( logger = logger, log_every_n_steps = 10, max_epochs = config_max_epochs, strategy = config_strategy, precision = config_precision, + sync_batchnorm = config_sync_batchnorm, callbacks = [ - ModelCheckpoint( + ModelWithConfigCheckpoint( monitor = 'generator_loss', dirpath = config_directory_path, filename = config_file_pattern, diff --git a/hyperswap/src/types.py b/hyperswap/src/types.py index 1ce949a..ffc597c 100644 --- a/hyperswap/src/types.py +++ b/hyperswap/src/types.py @@ -23,3 +23,6 @@ GazerModule : TypeAlias = Module FaceMaskerModule : TypeAlias = Module OptimizerSet : TypeAlias = Any + +TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true'] +TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16'] diff --git a/hyperswap/tests/test_networks.py b/hyperswap/tests/test_networks.py index 9f52bbb..3213f1c 100644 --- a/hyperswap/tests/test_networks.py +++ b/hyperswap/tests/test_networks.py @@ -16,7 +16,6 @@ def test_aad_with_unet(output_size : int) -> None: 'training.model.generator': { 'source_channels': '512', - 'output_channels': str(output_size * 16), 'output_size': str(output_size), 'num_blocks': '2' } diff --git a/requirements.txt b/requirements.txt index c8e44c4..74d4399 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ --extra-index-url https://download.pytorch.org/whl/cu128 albumentations==2.0.8 -lightning==2.5.1 +lightning==2.5.5 onnx==1.18.0 onnxruntime==1.22.0 pytorch-msssim==1.0.0 -torch==2.7.1 +torch==2.8.0 torchdata==0.11.0 -torchvision==0.22.1 -tensorboard==2.19.0 +torchvision==0.23.0 +tensorboard==2.20.0