* add sync batchnorm

* replace random.choice with hash

* fifty percent reduction

* fix discriminator input

* restore dataset.py

* Remove duplicates

* add discriminator_ratio to config

* fix onnx export bug: replace round() with int()

* Fix embedding naming

* Introduce ModelWithConfigCheckpoint callback (#86)

* Fix dist ini

* Style: Refactor typing and improve code clarity in training.py (#88)

* Add type casting for trainer params

* Add type casting for trainer params

* Add type casting for trainer params

* Remove inplace activations for torch.compile compatibility (#89)

* Fix README

* improvise with norm layers & weighted average

* add skip layer

* use gelu instead of leaky_relu

* cleanup

* cleanup

* Update dependencies

* Different defaults and enable validation

* Different defaults and enable validation

* Revert to higher batch size

* Just use copy over copy2

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
Co-authored-by: NeuroDonu <112660822+NeuroDonu@users.noreply.github.com>
Co-authored-by: Harisreedhar <46858047+harisreedhar@users.noreply.github.com>
This commit is contained in:
Henry Ruhs
2025-09-06 19:12:29 +02:00
committed by GitHub
parent 9f9f9dbad7
commit 2e6394565a
20 changed files with 187 additions and 164 deletions
+2 -2
View File
@@ -32,7 +32,7 @@ file_pattern = .datasets/megaface/**/*.jpg
```
[training.loader]
batch_size = 256
batch_size = 128
num_workers = 8
split_ratio = 0.95
```
@@ -90,7 +90,7 @@ python train.py
Launch the TensorBoard to monitor the training.
```
tensorboard --logdir=.logs
tensorboard --logdir .logs
```
+1 -1
View File
@@ -17,7 +17,7 @@ def export() -> None:
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
os.makedirs(config_directory_path, exist_ok = True)
model = CrossFaceTrainer.load_from_checkpoint(config_source_path, map_location ='cpu').eval()
model = CrossFaceTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
model.ir_version = torch.tensor(config_ir_version)
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
+21 -12
View File
@@ -1,28 +1,37 @@
import torch
from torch import Tensor, nn
class CrossFace(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layers = self.create_layers()
self.leaky_relu = nn.LeakyReLU()
self.sequence = self.create_sequence()
self.linear = nn.Linear(512, 512)
self.apply(init_weight)
@staticmethod
def create_layers() -> nn.ModuleList:
return nn.ModuleList(
[
def create_sequence() -> nn.Sequential:
return nn.Sequential(
nn.Linear(512, 1024),
nn.LayerNorm(1024),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(1024, 2048),
nn.LayerNorm(2048),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(1024, 512)
])
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = input_tensor / torch.norm(input_tensor)
temp_tensor = nn.functional.normalize(input_tensor, p = 2, dim = -1)
return self.sequence(temp_tensor) + 0.2 * self.linear(temp_tensor)
for layer in self.layers[:-1]:
output_tensor = self.leaky_relu(layer(output_tensor))
output_tensor = self.layers[-1](output_tensor)
return output_tensor
def init_weight(module : nn.Module) -> None:
if isinstance(module, nn.Linear):
nn.init.xavier_normal_(module.weight)
nn.init.constant_(module.bias, 0.01)
+20 -9
View File
@@ -1,10 +1,12 @@
import os
import shutil
from configparser import ConfigParser
from typing import Tuple
from pathlib import Path
from typing import Tuple, cast
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, random_split
@@ -12,7 +14,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import StaticDataset
from .models.crossface import CrossFace
from .types import Batch, Embedding, OptimizerSet
from .types import Batch, Embedding, OptimizerSet, TrainerPrecision, TrainerStrategy
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
@@ -67,6 +69,13 @@ class CrossFaceTrainer(LightningModule):
return optimizer_set
class ModelWithConfigCheckpoint(ModelCheckpoint):
    """
    Checkpoint callback that stores a copy of config.ini next to each saved checkpoint.
    """

    def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
        """
        Save the checkpoint, then copy config.ini to the checkpoint path with an .ini suffix.

        The config.ini path is relative to the current working directory.
        """
        super()._save_checkpoint(trainer, checkpoint_path)

        # under distributed strategies this hook can run on every process, so only global rank zero writes the copy
        if trainer.is_global_zero:
            config_path = Path(checkpoint_path).with_suffix('.ini')
            shutil.copy('config.ini', config_path)
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
@@ -89,8 +98,8 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T
def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
@@ -105,15 +114,17 @@ def create_trainer() -> Trainer:
precision = config_precision,
callbacks =
[
ModelCheckpoint(
ModelWithConfigCheckpoint(
monitor = 'training_loss',
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_epochs = 1,
every_n_epochs = 1000,
save_top_k = 5,
save_last = True
)
]
),
StochasticWeightAveraging(swa_lrs = 1e-2)
],
val_check_interval = 1000
)
+4 -1
View File
@@ -1,4 +1,4 @@
from typing import Any, TypeAlias
from typing import Any, Literal, TypeAlias
from torch import Tensor
@@ -6,3 +6,6 @@ Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
OptimizerSet : TypeAlias = Any
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
+3 -2
View File
@@ -54,7 +54,6 @@ face_masker_path = .models/face_masker.pt
```
[training.model.generator]
source_channels = 512
output_channels = 4096
output_size = 256
num_blocks = 2
```
@@ -89,10 +88,12 @@ mask_weight = 5.0
```
[training.trainer]
accumulate_size = 4
discriminator_ratio = 0.4
gradient_clip = 20.0
max_epochs = 50
strategy = auto
precision = 16-mixed
sync_batchnorm = false
preview_frequency = 100
```
@@ -164,7 +165,7 @@ python train.py
Launch the TensorBoard to monitor the training.
```
tensorboard --logdir=.logs
tensorboard --logdir .logs
```
+3 -2
View File
@@ -1,7 +1,7 @@
[training.dataset]
file_pattern =
convert_template =
multiplier
multiplier =
transform_size =
usage_mode =
batch_mode =
@@ -20,7 +20,6 @@ face_masker_path =
[training.model.generator]
source_channels =
output_channels =
output_size =
num_blocks =
@@ -47,10 +46,12 @@ mask_weight =
[training.trainer]
accumulate_size =
discriminator_ratio =
gradient_clip =
max_epochs =
strategy =
precision =
sync_batchnorm =
preview_frequency =
[training.modifier]
+38 -58
View File
@@ -43,6 +43,35 @@ class DynamicDataset(Dataset[Tensor]):
def __len__(self) -> int:
return len(resolve_static_file_pattern(self.config_file_pattern))
def prepare_equal_batch(self, source_path : str) -> Batch:
    """
    Build a batch in which the same file is used as both source and target.

    Both sides use this dataset's own convert template.
    NOTE(review): create_batch reads and transforms each path separately, so the two
    tensors may differ if self.transforms is randomized - confirm this is intended.
    """
    convert_template = self.config_convert_template
    return self.create_batch(source_path, source_path, convert_template, convert_template)
def prepare_same_batch(self, source_path : str) -> Batch:
    """
    Pair the source file with a randomly chosen entry from its own directory.

    The entry is drawn from os.listdir of the source directory, so the pick is not
    filtered in any way and may be the source file itself.
    """
    source_directory_path = os.path.dirname(source_path)
    directory_entries = os.listdir(source_directory_path)
    target_path = os.path.join(source_directory_path, random.choice(directory_entries))
    return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template)
def prepare_source_batch(self, source_path : str) -> Batch:
    """
    Keep the given file as source and draw the target from a random config section.

    The section is picked from the config returned by filter_config_by_usage_mode('both');
    the target is converted with that section's template, the source with this dataset's own.
    """
    both_config_parser = self.filter_config_by_usage_mode('both')
    random_section = random.choice(both_config_parser.sections())
    target_file_pattern = both_config_parser.get(random_section, 'file_pattern')
    target_convert_template = cast(ConvertTemplate, both_config_parser.get(random_section, 'convert_template'))
    target_paths = resolve_static_file_pattern(target_file_pattern)
    target_path = random.choice(target_paths)
    return self.create_batch(source_path, target_path, self.config_convert_template, target_convert_template)
def prepare_target_batch(self, target_path : str) -> Batch:
    """
    Keep the given file as target and draw the source from a random config section.

    The section is picked from the config returned by filter_config_by_usage_mode('both');
    the source is converted with that section's template, the target with this dataset's own.
    """
    both_config_parser = self.filter_config_by_usage_mode('both')
    random_section = random.choice(both_config_parser.sections())
    source_file_pattern = both_config_parser.get(random_section, 'file_pattern')
    source_convert_template = cast(ConvertTemplate, both_config_parser.get(random_section, 'convert_template'))
    source_paths = resolve_static_file_pattern(source_file_pattern)
    source_path = random.choice(source_paths)
    return self.create_batch(source_path, target_path, source_convert_template, self.config_convert_template)
def prepare_different_batch(self, source_path : str) -> Batch:
    """
    Pair the source file with a random file matched by this dataset's own file pattern.

    Both sides use this dataset's own convert template.
    """
    candidate_paths = resolve_static_file_pattern(self.config_file_pattern)
    random_target_path = random.choice(candidate_paths)
    return self.create_batch(source_path, random_target_path, self.config_convert_template, self.config_convert_template)
def compose_transforms(self) -> transforms:
return transforms.Compose(
[
@@ -53,64 +82,6 @@ class DynamicDataset(Dataset[Tensor]):
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def prepare_equal_batch(self, source_path : str) -> Batch:
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
return source_tensor, source_tensor
def prepare_same_batch(self, source_path : str) -> Batch:
target_directory_path = os.path.dirname(source_path)
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
def prepare_source_batch(self, source_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
target_path = random.choice(resolve_static_file_pattern(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, config_convert_template)
return source_tensor, target_tensor
def prepare_target_batch(self, target_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
source_path = random.choice(resolve_static_file_pattern(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser:
config_parser = ConfigParser()
@@ -126,6 +97,15 @@ class DynamicDataset(Dataset[Tensor]):
return config_parser
def create_batch(self, source_path : str, target_path : str, source_convert_template : ConvertTemplate, target_convert_template : ConvertTemplate) -> Batch:
    """
    Load the source and target image and return them as a (source, target) tensor pair.

    Each image is read, passed through self.transforms and then through
    conditional_convert_tensor with its own convert template. The source is fully
    processed before the target is read.
    """
    output_tensors = []

    for file_path, convert_template in ((source_path, source_convert_template), (target_path, target_convert_template)):
        temp_tensor = self.transforms(io.read_image(file_path))
        output_tensors.append(self.conditional_convert_tensor(temp_tensor, convert_template))

    source_tensor, target_tensor = output_tensors
    return source_tensor, target_tensor
@staticmethod
def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
if convert_template:
+1 -1
View File
@@ -36,7 +36,7 @@ def export() -> None:
config_precision = CONFIG_PARSER.get('exporting', 'precision')
os.makedirs(config_directory_path, exist_ok = True)
model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location ='cpu').eval()
model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
if config_precision == 'half':
model = HalfPrecision(model).eval()
+6 -6
View File
@@ -34,7 +34,7 @@ def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) ->
return output_tensor
def calculate_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
def calculate_face_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
@@ -42,9 +42,9 @@ def calculate_embedding(embedder : EmbedderModule, input_tensor : Tensor, paddin
crop_tensor[:, :, :, :padding[2]] = 0
crop_tensor[:, :, :, 112 - padding[3]:] = 0
embedding = embedder(crop_tensor)
embedding = nn.functional.normalize(embedding, p = 2)
return embedding
face_embedding = embedder(crop_tensor)
face_embedding = nn.functional.normalize(face_embedding, p = 2)
return face_embedding
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
@@ -56,7 +56,7 @@ def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = round(input_tensor.shape[2] * factor)
padding = int(input_tensor.shape[2] * factor + 0.5)
kernel_size = 1 + 2 * padding
temp_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
output_tensor = nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
@@ -64,7 +64,7 @@ def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
def erode_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = round(input_tensor.shape[2] * factor)
padding = int(input_tensor.shape[2] * factor + 0.5)
kernel_size = 1 + 2 * padding
temp_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
output_tensor = 1 - nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
+2 -2
View File
@@ -3,7 +3,7 @@ import configparser
import torch
from torchvision import io
from .helper import calculate_embedding
from .helper import calculate_face_embedding
from .training import HyperSwapTrainer
CONFIG_PARSER = configparser.ConfigParser()
@@ -22,6 +22,6 @@ def infer() -> None:
source_tensor = io.read_image(config_source_path)
target_tensor = io.read_image(config_target_path)
source_embedding = calculate_embedding(embedder, source_tensor, (0, 0, 0, 0))
source_embedding = calculate_face_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = generator(source_embedding, target_tensor)
io.write_jpeg(output_tensor, config_output_path)
+10 -10
View File
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
from torch import Tensor, nn
from torchvision import transforms
from ..helper import calculate_embedding, dilate_mask
from ..helper import calculate_face_embedding, dilate_mask
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
@@ -14,16 +14,16 @@ class DiscriminatorLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss:
def forward(self, discriminator_real_tensors : List[Tensor], discriminator_fake_tensors : List[Tensor]) -> Loss:
positive_tensors = []
negative_tensors = []
for discriminator_source_tensor in discriminator_source_tensors:
positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ])
for discriminator_real_tensor in discriminator_real_tensors:
positive_tensor = torch.relu(1 - discriminator_real_tensor).mean(dim = [ 1, 2, 3 ])
positive_tensors.append(positive_tensor)
for discriminator_output_tensor in discriminator_output_tensors:
negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ])
for discriminator_fake_tensor in discriminator_fake_tensors:
negative_tensor = torch.relu(discriminator_fake_tensor + 1).mean(dim = [ 1, 2, 3 ])
negative_tensors.append(negative_tensor)
positive_loss = torch.stack(positive_tensors).mean()
@@ -97,8 +97,8 @@ class ReconstructionLoss(nn.Module):
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
with torch.no_grad():
source_embedding = calculate_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calculate_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
source_embedding = calculate_face_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calculate_face_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
@@ -120,8 +120,8 @@ class IdentityLoss(nn.Module):
self.embedder = embedder
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
output_embedding = calculate_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calculate_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
output_embedding = calculate_face_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calculate_face_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
weighted_identity_loss = identity_loss * self.config_identity_weight
return identity_loss, weighted_identity_loss
+15 -16
View File
@@ -11,10 +11,9 @@ class AAD(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels)
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, 4096)
self.layers = self.create_layers()
def create_layers(self) -> nn.ModuleList:
@@ -23,9 +22,9 @@ class AAD(nn.Module):
if self.config_output_size == 128:
layers.extend(
[
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks)
AdaptiveFeatureModulation(1024, 1024, 512, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 256:
@@ -40,21 +39,21 @@ class AAD(nn.Module):
if self.config_output_size == 512:
layers.extend(
[
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 1024:
layers.extend(
[
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 3072, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
@@ -100,7 +99,7 @@ class AdaptiveFeatureModulation(nn.Module):
primary_layers.extend(
[
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
nn.ReLU(inplace = True)
nn.ReLU()
])
if index < self.context_num_blocks - 1:
@@ -117,7 +116,7 @@ class AdaptiveFeatureModulation(nn.Module):
shortcut_layers.extend(
[
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
nn.ReLU(inplace = True),
nn.ReLU(),
nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
])
+5 -5
View File
@@ -56,17 +56,17 @@ class BottleNeck(nn.Module):
def __init__(self, num_filters : int):
super().__init__()
self.sequences = self.create_sequences(num_filters)
self.relu = nn.ReLU(inplace = True)
self.relu = nn.ReLU()
@staticmethod
def create_sequences(num_filters : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU(inplace = True),
nn.ReLU(),
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU(inplace = True)
nn.ReLU()
)
def forward(self, input_tensor : Tensor) -> Tensor:
@@ -84,7 +84,7 @@ class UpSample(nn.Module):
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2),
nn.ReLU(inplace = True)
nn.ReLU()
)
def forward(self, input_tensor : Tensor) -> Tensor:
@@ -102,7 +102,7 @@ class DownSample(nn.Module):
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace = True),
nn.ReLU(),
nn.MaxPool2d(2)
)
+3 -3
View File
@@ -20,7 +20,7 @@ class NLD(nn.Module):
layers = nn.ModuleList(
[
nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.LeakyReLU(0.2, True)
nn.LeakyReLU(0.2)
])
for _ in range(1, self.config_num_layers):
@@ -30,7 +30,7 @@ class NLD(nn.Module):
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True)
nn.LeakyReLU(0.2)
]
previous_filters = current_filters
@@ -39,7 +39,7 @@ class NLD(nn.Module):
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True),
nn.LeakyReLU(0.2),
nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding)
]
return layers
+18 -15
View File
@@ -41,8 +41,8 @@ class UNet(nn.Module):
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 2048),
DownSample(2048, 2048)
DownSample(1024, 1024),
DownSample(1024, 1024)
])
if self.config_output_size == 1024:
@@ -50,8 +50,8 @@ class UNet(nn.Module):
[
DownSample(512, 1024),
DownSample(1024, 2048),
DownSample(2048, 4096),
DownSample(4096, 4096)
DownSample(2048, 2048),
DownSample(2048, 2048)
])
return down_samples
@@ -62,36 +62,39 @@ class UNet(nn.Module):
if self.config_output_size == 128:
up_samples.extend(
[
UpSample(512, 512)
UpSample(512, 512),
UpSample(1024, 256)
])
if self.config_output_size == 256:
up_samples.extend(
[
UpSample(1024, 1024),
UpSample(2048, 512)
UpSample(2048, 512),
UpSample(1024, 256)
])
if self.config_output_size == 512:
up_samples.extend(
[
UpSample(2048, 2048),
UpSample(4096, 1024),
UpSample(2048, 512)
UpSample(1024, 1024),
UpSample(2048, 512),
UpSample(1536, 256),
UpSample(768, 256)
])
if self.config_output_size == 1024:
up_samples.extend(
[
UpSample(4096, 4096),
UpSample(8192, 2048),
UpSample(2048, 2048),
UpSample(4096, 1024),
UpSample(2048, 512)
UpSample(3072, 512),
UpSample(1536, 256),
UpSample(768, 256)
])
up_samples.extend(
[
UpSample(1024, 256),
UpSample(512, 128),
UpSample(256, 64),
UpSample(128, 32)
@@ -130,7 +133,7 @@ class UpSample(nn.Module):
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.1, inplace = True)
nn.LeakyReLU(0.1)
)
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
@@ -149,7 +152,7 @@ class DownSample(nn.Module):
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.1, inplace = True)
nn.LeakyReLU(0.1)
)
def forward(self, input_tensor : Tensor) -> Tensor:
+28 -14
View File
@@ -1,8 +1,10 @@
import os
import shutil
import warnings
from configparser import ConfigParser
from copy import deepcopy
from typing import List, Tuple
from pathlib import Path
from typing import List, Tuple, cast
import torch
import torchvision
@@ -14,11 +16,11 @@ from torch.utils.data import ConcatDataset, Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import DynamicDataset
from .helper import apply_noise, calculate_embedding, erode_mask, overlay_mask
from .helper import apply_noise, calculate_face_embedding, erode_mask, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
from .types import Batch, Embedding, Mask, OptimizerSet
from .types import Batch, Embedding, Mask, OptimizerSet, TrainerPrecision, TrainerStrategy
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -34,6 +36,7 @@ class HyperSwapTrainer(LightningModule):
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_discriminator_ratio = config_parser.getfloat('training.trainer', 'discriminator_ratio')
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
@@ -101,8 +104,8 @@ class HyperSwapTrainer(LightningModule):
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined]
source_embedding = calculate_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calculate_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calculate_face_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
if self.config_noise_factor > 0:
source_embedding = apply_noise(source_embedding, self.config_noise_factor)
@@ -123,9 +126,12 @@ class HyperSwapTrainer(LightningModule):
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
# sample uniformly from [0, 1) so that config_discriminator_ratio is the probability of using the source as the real sample
# torch.randn draws from a standard normal, which makes the comparison true far more often than the configured ratio
if torch.rand(1).item() < self.config_discriminator_ratio:
    discriminator_real_tensors = self.discriminator(source_tensor)
else:
    discriminator_real_tensors = self.discriminator(target_tensor)
# detach the generator output so the discriminator loss does not backpropagate into the generator
discriminator_fake_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_real_tensors, discriminator_fake_tensors)
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
@@ -176,9 +182,9 @@ class HyperSwapTrainer(LightningModule):
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
source_embedding = calculate_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = self.forward(source_embedding, target_tensor)
output_embedding = calculate_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
output_embedding = calculate_face_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
return validation_score
@@ -197,6 +203,13 @@ class HyperSwapTrainer(LightningModule):
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
class ModelWithConfigCheckpoint(ModelCheckpoint):
    """
    Checkpoint callback that stores a copy of config.ini next to each saved checkpoint.
    """

    def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
        """
        Save the checkpoint, then copy config.ini to the checkpoint path with an .ini suffix.

        The config.ini path is relative to the current working directory.
        """
        super()._save_checkpoint(trainer, checkpoint_path)

        # under distributed strategies this hook can run on every process, so only global rank zero writes the copy
        if trainer.is_global_zero:
            config_path = Path(checkpoint_path).with_suffix('.ini')
            shutil.copy('config.ini', config_path)
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
@@ -239,23 +252,24 @@ def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
config_sync_batchnorm = CONFIG_PARSER.getboolean('training.trainer', 'sync_batchnorm')
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
logger = TensorBoardLogger(config_logger_path, config_logger_name)
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = config_max_epochs,
strategy = config_strategy,
precision = config_precision,
sync_batchnorm = config_sync_batchnorm,
callbacks =
[
ModelCheckpoint(
ModelWithConfigCheckpoint(
monitor = 'generator_loss',
dirpath = config_directory_path,
filename = config_file_pattern,
+3
View File
@@ -23,3 +23,6 @@ GazerModule : TypeAlias = Module
FaceMaskerModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
-1
View File
@@ -16,7 +16,6 @@ def test_aad_with_unet(output_size : int) -> None:
'training.model.generator':
{
'source_channels': '512',
'output_channels': str(output_size * 16),
'output_size': str(output_size),
'num_blocks': '2'
}
+4 -4
View File
@@ -1,10 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cu128
albumentations==2.0.8
lightning==2.5.1
lightning==2.5.5
onnx==1.18.0
onnxruntime==1.22.0
pytorch-msssim==1.0.0
torch==2.7.1
torch==2.8.0
torchdata==0.11.0
torchvision==0.22.1
tensorboard==2.19.0
torchvision==0.23.0
tensorboard==2.20.0