Merge pull request #72 from facefusion/next

Next
This commit is contained in:
Henry Ruhs
2025-04-23 20:50:08 +02:00
committed by GitHub
51 changed files with 1958 additions and 379 deletions
+2 -4
View File
@@ -1,7 +1,5 @@
[flake8]
select = E3, E4, F, I1, I2
select = E22, E23, E24, E27, E3, E4, E7, F, I1, I2
plugins = flake8-import-order
application_import_names = arcface_converter
application_import_names = embedding_converter, face_swapper
import-order-style = pycharm
per-file-ignores = preparing.py:E402

Before

Width:  |  Height:  |  Size: 1.3 MiB

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.2 MiB

+16 -4
View File
@@ -8,12 +8,24 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python 3.10
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.12'
- run: pip install flake8
- run: pip install flake8-import-order
- run: pip install mypy
- run: flake8 arcface_converter
- run: mypy arcface_converter
- run: flake8 embedding_converter face_swapper
- run: mypy embedding_converter face_swapper
test:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pip install torch torchvision
- run: pip install pytest
- run: PYTHONPATH=/home/runner/work/facefusion-labs/facefusion-labs pytest
+7
View File
@@ -1,3 +1,10 @@
__pycache__
.assets
.datasets
.idea
.inputs
.exports
.logs
.models
.outputs
.vscode
-3
View File
@@ -1,3 +0,0 @@
MIT license
Copyright (c) 2024 Henry Ruhs
-93
View File
@@ -1,93 +0,0 @@
ArcFace Converter
=================
> Convert face embeddings between various ArcFace models.
![License](https://img.shields.io/badge/license-MIT-green)
Preview
-------
![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/master/.github/preview_arcface_converter.png?sanitize=true)
Installation
------------
```
pip install -r requirements.txt
```
Example
-------
This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap.
```
[preparing.dataset]
dataset_path = datasets/megaface/train.rec
crop_size = 112
process_limit = 650000
[preparing.model]
source_path = models/arcface_w600k_r50.onnx
target_path = models/arcface_simswap.onnx
[preparing.input]
directory_path = inputs
source_path = inputs/arcface_w600k_r50.npy
target_path = inputs/arcface_simswap.npy
[training.loader]
split_ratio = 0.8
batch_size = 51200
num_workers = 8
[training.trainer]
max_epochs = 4096
[training.output]
directory_path = outputs
file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f}
[exporting]
directory_path = exports
source_path = outputs/last.ckpt
target_path = exports/arcface_converter_simswap.onnx
opset_version = 15
[execution]
providers = CUDAExecutionProvider
```
Preparing
---------
Prepare the face embedding pairs.
```
python prepare.py
```
Training
--------
Train the ArcFace converter model.
```
python train.py
```
Exporting
---------
Export the model to ONNX.
```
python export.py
```
-22
View File
@@ -1,22 +0,0 @@
import configparser
from os import makedirs
import torch
from .training import ArcFaceConverterTrainer
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def export() -> None:
directory_path = CONFIG.get('exporting', 'directory_path')
source_path = CONFIG.get('exporting', 'source_path')
target_path = CONFIG.get('exporting', 'target_path')
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
model = ArcFaceConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
-21
View File
@@ -1,21 +0,0 @@
import torch
import torch.nn as nn
from torch import Tensor
class ArcFaceConverter(nn.Module):
def __init__(self) -> None:
super(ArcFaceConverter, self).__init__()
self.fc1 = nn.Linear(512, 1024)
self.fc2 = nn.Linear(1024, 2048)
self.fc3 = nn.Linear(2048, 1024)
self.fc4 = nn.Linear(1024, 512)
self.activation = nn.LeakyReLU()
def forward(self, inputs : Tensor) -> Tensor:
norm_inputs = inputs / torch.norm(inputs)
outputs = self.activation(self.fc1(norm_inputs))
outputs = self.activation(self.fc2(outputs))
outputs = self.activation(self.fc3(outputs))
outputs = self.fc4(outputs)
return outputs
-79
View File
@@ -1,79 +0,0 @@
import configparser
from os import makedirs
from os.path import isfile
from typing import List
import numpy
numpy.bool = numpy.bool_
from mxnet.io import ImageRecordIter
from onnxruntime import InferenceSession
from tqdm import tqdm
from .typing import Embedding, EmbeddingPairs, VisionFrame
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
crop_vision_frame = crop_vision_frame.astype(numpy.float32) / 255
crop_vision_frame = (crop_vision_frame - 0.5) * 2
return crop_vision_frame
def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession:
inference_session = InferenceSession(model_path, providers = execution_providers)
return inference_session
def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding:
embedding = inference_session.run(None,
{
'input': crop_vision_frame
})[0]
return embedding
def process_embeddings(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingPairs:
dataset_process_limit = CONFIG.getint('preparing.dataset', 'process_limit')
embedding_pairs = []
with tqdm(total = dataset_process_limit) as progress:
for batch in dataset_reader:
crop_vision_frame = batch.data[0].asnumpy()
crop_vision_frame = prepare_crop_vision_frame(crop_vision_frame)
source_embedding = forward(source_inference_session, crop_vision_frame)
target_embedding = forward(target_inference_session, crop_vision_frame)
embedding_pairs.append([ source_embedding, target_embedding ])
progress.update()
if progress.n == dataset_process_limit:
return numpy.concatenate(embedding_pairs, axis = 1).T
return numpy.concatenate(embedding_pairs, axis = 1).T
def prepare() -> None:
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size')
model_source_path = CONFIG.get('preparing.model', 'source_path')
model_target_path = CONFIG.get('preparing.model', 'target_path')
input_directory_path = CONFIG.get('preparing.input', 'directory_path')
input_source_path = CONFIG.get('preparing.input', 'source_path')
input_target_path = CONFIG.get('preparing.input', 'target_path')
execution_providers = CONFIG.get('execution', 'providers').split(' ')
makedirs(input_directory_path, exist_ok = True)
if isfile(dataset_path) and isfile(model_source_path) and isfile(model_target_path):
dataset_reader = ImageRecordIter(
path_imgrec = dataset_path,
data_shape = (3, dataset_crop_size, dataset_crop_size),
batch_size = 1,
shuffle = False
)
source_inference_session = create_inference_session(model_source_path, execution_providers)
target_inference_session = create_inference_session(model_target_path, execution_providers)
embedding_pairs = process_embeddings(dataset_reader, source_inference_session, target_inference_session)
numpy.save(input_source_path, embedding_pairs[..., 0].T)
numpy.save(input_target_path, embedding_pairs[..., 1].T)
-116
View File
@@ -1,116 +0,0 @@
import configparser
from typing import Any, Tuple
import numpy
import pytorch_lightning
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.tuner.tuning import Tuner
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from .model import ArcFaceConverter
from .typing import Batch, Loader
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class ArcFaceConverterTrainer(pytorch_lightning.LightningModule):
def __init__(self) -> None:
super(ArcFaceConverterTrainer, self).__init__()
self.model = ArcFaceConverter()
self.loss_fn = torch.nn.MSELoss()
self.lr = 0.001
def forward(self, source_embedding : Tensor) -> Tensor:
return self.model(source_embedding)
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_embedding, target_embedding = batch
output_embedding = self(source_embedding)
loss = self.loss_fn(output_embedding, target_embedding)
self.log('train_loss', loss, prog_bar = True, logger = True)
return loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_embedding, target_embedding = batch
output_embedding = self(source_embedding)
loss = self.loss_fn(output_embedding, target_embedding)
self.log('val_loss', loss, prog_bar = True, logger = True)
return loss
def configure_optimizers(self) -> Any:
optimizer = torch.optim.Adam(self.parameters(), lr = self.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return\
{
'optimizer': optimizer,
'lr_scheduler':
{
'scheduler': scheduler,
'monitor': 'train_loss',
'interval': 'epoch',
'frequency': 1
}
}
def create_loaders() -> Tuple[Loader, Loader]:
loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
loader_num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset()
training_loader = DataLoader(training_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = True, pin_memory = True)
validation_loader = DataLoader(validate_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = False, pin_memory = True)
return training_loader, validation_loader
def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
input_source_path = CONFIG.get('preparing.input', 'source_path')
input_target_path = CONFIG.get('preparing.input', 'target_path')
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
source_input = torch.from_numpy(numpy.load(input_source_path)).float()
target_input = torch.from_numpy(numpy.load(input_target_path)).float()
dataset = TensorDataset(source_input, target_input)
dataset_size = len(dataset)
training_size = int(loader_split_ratio * len(dataset))
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset
def create_trainer() -> Trainer:
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
return Trainer(
max_epochs = trainer_max_epochs,
callbacks =
[
ModelCheckpoint(
monitor = 'train_loss',
dirpath = output_directory_path,
filename = output_file_pattern,
every_n_epochs = 10,
save_top_k = 3,
save_last = True
)
],
enable_progress_bar = True,
log_every_n_steps = 2
)
def train() -> None:
trainer = create_trainer()
training_loader, validation_loader = create_loaders()
model = ArcFaceConverterTrainer()
tuner = Tuner(trainer)
tuner.lr_find(model, training_loader, validation_loader)
trainer.fit(model, training_loader, validation_loader)
-13
View File
@@ -1,13 +0,0 @@
from typing import Any, Tuple
from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader
Batch = Tuple[Tensor, Tensor]
Loader = DataLoader[Tuple[Tensor, ...]]
Embedding = NDArray[Any]
EmbeddingPairs = NDArray[Any]
FaceLandmark5 = NDArray[Any]
VisionFrame = NDArray[Any]
+3
View File
@@ -0,0 +1,3 @@
OpenRAIL-MS license
Copyright (c) 2025 Henry Ruhs
+96
View File
@@ -0,0 +1,96 @@
Embedding Converter
===================
> Convert face embeddings between various models.
![License](https://img.shields.io/badge/license-OpenRAIL--MS-green)
Preview
-------
![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/next/.github/previews/embedding_converter.png?sanitize=true)
Installation
------------
```
pip install -r requirements.txt
```
Setup
-----
This `config.ini` utilizes the MegaFace dataset to train the Embedding Converter for SimSwap.
```
[training.dataset]
file_pattern = .datasets/megaface/**/*.jpg
```
```
[training.loader]
batch_size = 256
num_workers = 8
split_ratio = 0.95
```
```
[training.model]
source_path = .models/arcface_w600k_r50.pt
target_path = .models/arcface_simswap.pt
```
```
[training.trainer]
learning_rate = 0.001
max_epochs = 4096
strategy = auto
precision = 16-mixed
logger_path = .logs
logger_name = arcface_converter_simswap
```
```
[training.output]
directory_path = .outputs
file_pattern = arcface_converter_simswap_{epoch}_{step}
resume_path = .outputs/last.ckpt
```
```
[exporting]
directory_path = .exports
source_path = .outputs/last.ckpt
target_path = .exports/arcface_converter_simswap.onnx
ir_version = 10
opset_version = 15
```
Training
--------
Train the Embedding Converter model.
```
python train.py
```
Launch the TensorBoard to monitor the training.
```
tensorboard --logdir=.logs
```
Exporting
---------
Export the model to ONNX.
```
python export.py
```
@@ -1,34 +1,31 @@
[preparing.dataset]
dataset_path =
crop_size =
process_limit =
[preparing.model]
source_path =
target_path =
[preparing.input]
directory_path =
source_path =
target_path =
[training.dataset]
file_pattern =
[training.loader]
split_ratio =
batch_size =
num_workers =
split_ratio =
[training.model]
source_path =
target_path =
[training.trainer]
learning_rate =
max_epochs =
strategy =
precision =
logger_path =
logger_name =
[training.output]
directory_path =
file_pattern =
resume_path =
[exporting]
directory_path =
source_path =
target_path =
ir_version =
opset_version =
[execution]
providers =
+34
View File
@@ -0,0 +1,34 @@
import glob
from configparser import ConfigParser
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .types import Batch
class StaticDataset(Dataset[Tensor]):
def __init__(self, config_parser : ConfigParser) -> None:
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
self.file_paths = glob.glob(self.config_file_pattern)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
temp_tensor = io.read_image(file_path)
return self.transforms(temp_tensor)
def __len__(self) -> int:
return len(self.file_paths)
@staticmethod
def compose_transforms() -> transforms:
return transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
+23
View File
@@ -0,0 +1,23 @@
import os
from configparser import ConfigParser
import torch
from .training import EmbeddingConverterTrainer
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
def export() -> None:
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
os.makedirs(config_directory_path, exist_ok = True)
model = EmbeddingConverterTrainer.load_from_checkpoint(config_source_path, map_location = 'cpu').eval()
model.ir_version = torch.tensor(config_ir_version)
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
@@ -0,0 +1,28 @@
import torch
from torch import Tensor, nn
class EmbeddingConverter(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layers = self.create_layers()
self.leaky_relu = nn.LeakyReLU()
@staticmethod
def create_layers() -> nn.ModuleList:
return nn.ModuleList(
[
nn.Linear(512, 1024),
nn.Linear(1024, 2048),
nn.Linear(2048, 1024),
nn.Linear(1024, 512)
])
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = input_tensor / torch.norm(input_tensor)
for layer in self.layers[:-1]:
output_tensor = self.leaky_relu(layer(output_tensor))
output_tensor = self.layers[-1](output_tensor)
return output_tensor
+134
View File
@@ -0,0 +1,134 @@
import os
from configparser import ConfigParser
from typing import Tuple
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import StaticDataset
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Embedding, OptimizerSet
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class EmbeddingConverterTrainer(LightningModule):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_source_path = config_parser.get('training.model', 'source_path')
self.config_target_path = config_parser.get('training.model', 'target_path')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
self.embedding_converter = EmbeddingConverter()
self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval()
self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval()
self.mse_loss = nn.MSELoss()
def forward(self, source_embedding : Embedding) -> Embedding:
return self.embedding_converter(source_embedding)
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
with torch.no_grad():
source_embedding = self.source_embedder(batch)
target_embedding = self.target_embedder(batch)
output_embedding = self(source_embedding)
training_loss = self.mse_loss(output_embedding, target_embedding)
self.log('training_loss', training_loss, prog_bar = True)
return training_loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
with torch.no_grad():
source_embedding = self.source_embedder(batch)
output_embedding = self(source_embedding)
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
return validation_score
def configure_optimizers(self) -> OptimizerSet:
optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
optimizer_set =\
{
'optimizer': optimizer,
'lr_scheduler':
{
'scheduler': scheduler,
'monitor': 'training_loss',
'interval': 'epoch',
'frequency': 1
}
}
return optimizer_set
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
dataset_size = len(dataset) # type:ignore[arg-type]
training_size = int(dataset_size * config_split_ratio)
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset
def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
logger = TensorBoardLogger(config_logger_path, config_logger_name)
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = config_max_epochs,
strategy = config_strategy,
precision = config_precision,
callbacks =
[
ModelCheckpoint(
monitor = 'training_loss',
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_epochs = 1,
save_top_k = 3,
save_last = True
)
]
)
def train() -> None:
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = StaticDataset(CONFIG_PARSER)
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer(CONFIG_PARSER)
trainer = create_trainer()
if os.path.exists(config_resume_path):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
else:
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
+8
View File
@@ -0,0 +1,8 @@
from typing import Any, TypeAlias
from torch import Tensor
Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
OptimizerSet : TypeAlias = Any
+3
View File
@@ -0,0 +1,3 @@
ResearchRAIL-MS license
Copyright (c) 2025 Henry Ruhs
+160
View File
@@ -0,0 +1,160 @@
Face Swapper
============
> Face shape and occlusion aware identity transfer.
![License](https://img.shields.io/badge/license-ResearchRAIL--MS-red)
Preview
-------
![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/next/.github/previews/face_swapper.png?sanitize=true)
Installation
------------
```
pip install -r requirements.txt
```
Setup
-----
This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
```
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
warp_template = vgg_face_hq_to_arcface_128_v2
transform_size = 256
batch_mode = equal
batch_ratio = 0.2
```
```
[training.loader]
batch_size = 8
num_workers = 8
split_ratio = 0.9995
```
```
[training.model]
generator_embedder_path = .models/blendface.pt
loss_embedder_path = .models/adaface.pt
gazer_path = .models/gazer.pt
face_masker_path = .models/face_masker.pt
```
```
[training.model.generator]
source_channels = 512
output_channels = 4096
output_size = 256
num_blocks = 2
```
```
[training.model.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3
kernel_size = 4
```
```
[training.model.masker]
input_channels = 67
output_channels = 1
num_filters = 16
```
```
[training.losses]
adversarial_weight = 1.0
cycle_weight = 1.0
feature_weight = 10.0
reconstruction_weight = 10.0
identity_weight = 20.0
gaze_weight = 0.05
mask_weight = 5.0
```
```
[training.trainer]
accumulate_size = 4
learning_rate = 0.0004
max_epochs = 50
strategy = auto
precision = 16-mixed
logger_path = .logs
logger_name = face_swapper
preview_frequency = 100
```
```
[training.output]
directory_path = .outputs
file_pattern = face_swapper_{epoch}_{step}
resume_path = .outputs/last.ckpt
```
```
[exporting]
directory_path = .exports
source_path = .outputs/last.ckpt
target_path = .exports/face_swapper.onnx
target_size = 256
ir_version = 10
opset_version = 15
precision = full
```
```
[inferencing]
generator_path = .outputs/last.ckpt
embedder_path = .models/arcface.pt
source_path = .assets/source.jpg
target_path = .assets/target.jpg
output_path = .outputs/output.jpg
```
Training
--------
Train the Face Swapper model.
```
python train.py
```
Launch the TensorBoard to monitor the training.
```
tensorboard --logdir=.logs
```
Exporting
---------
Export the model to ONNX.
```
python export.py
```
Inferencing
-----------
Inference the model.
```
python infer.py
```
View File
+75
View File
@@ -0,0 +1,75 @@
[training.dataset]
file_pattern =
warp_template =
transform_size =
batch_mode =
batch_ratio =
[training.loader]
batch_size =
num_workers =
split_ratio =
[training.model]
generator_embedder_path =
loss_embedder_path =
gazer_path =
face_masker_path =
[training.model.generator]
source_channels =
output_channels =
output_size =
num_blocks =
[training.model.discriminator]
input_channels =
num_filters =
num_layers =
num_discriminators =
kernel_size =
[training.model.masker]
input_channels =
output_channels =
num_filters =
[training.losses]
adversarial_weight =
cycle_weight =
feature_weight =
reconstruction_weight =
identity_weight =
gaze_weight =
mask_weight =
[training.trainer]
accumulate_size =
learning_rate =
max_epochs =
strategy =
precision =
logger_path =
logger_name =
preview_frequency =
[training.output]
directory_path =
file_pattern =
resume_path =
[exporting]
directory_path =
source_path =
target_path =
target_size =
ir_version =
opset_version =
precision =
[inferencing]
generator_path =
embedder_path =
source_path =
target_path =
output_path =
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from src.preparing import prepare
from src.exporting import export
if __name__ == '__main__':
prepare()
export()
+6
View File
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
from src.inferencing import infer
if __name__ == '__main__':
infer()
View File
+107
View File
@@ -0,0 +1,107 @@
import glob
import os
import random
from configparser import ConfigParser
from typing import cast
import albumentations
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import warp_tensor
from .types import Batch, BatchMode, WarpTemplate
class DynamicDataset(Dataset[Tensor]):
def __init__(self, config_parser : ConfigParser) -> None:
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
self.config_transform_size = config_parser.getint('training.dataset', 'transform_size')
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset', 'batch_mode'))
self.config_batch_ratio = config_parser.getfloat('training.dataset', 'batch_ratio')
self.config_parser = config_parser
self.file_paths = glob.glob(self.config_file_pattern)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
if random.random() < self.config_batch_ratio:
if self.config_batch_mode == 'equal':
return self.prepare_equal_batch(file_path)
if self.config_batch_mode == 'same':
return self.prepare_same_batch(file_path)
return self.prepare_different_batch(file_path)
def __len__(self) -> int:
return len(self.file_paths)
def compose_transforms(self) -> transforms:
return transforms.Compose(
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
WarpTransform(self.config_parser),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor
def prepare_equal_batch(self, source_path : str) -> Batch:
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
return source_tensor, source_tensor
def prepare_same_batch(self, source_path : str) -> Batch:
target_directory_path = os.path.dirname(source_path)
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor
class AugmentTransform:
def __init__(self) -> None:
self.transforms = self.compose_transforms()
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
return self.transforms(image = temp_tensor).get('image')
@staticmethod
def compose_transforms() -> albumentations.Compose:
return albumentations.Compose(
[
albumentations.HorizontalFlip(),
albumentations.OneOf(
[
albumentations.MotionBlur(p = 0.1),
albumentations.ZoomBlur(max_factor = (1.0, 1.1), p = 0.1)
], p = 0.2),
albumentations.RandomBrightnessContrast(p = 0.7),
albumentations.ColorJitter(p = 0.2),
albumentations.RGBShift(p = 0.7),
albumentations.Illumination(p = 0.2),
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3)
])
class WarpTransform:
def __init__(self, config_parser : ConfigParser) -> None:
self.config_warp_template = cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.unsqueeze(0)
return warp_tensor(temp_tensor, self.config_warp_template).squeeze(0)
+47
View File
@@ -0,0 +1,47 @@
import os
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from .training import FaceSwapperTrainer
from .types import Embedding, Mask, Module
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class HalfPrecision(nn.Module):
def __init__(self, model : Module) -> None:
super().__init__()
self.model = model.half()
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
source_embedding = source_embedding.half()
target_tensor = target_tensor.half()
output_tensor, output_mask = self.model(source_embedding, target_tensor)
output_tensor = output_tensor.float()
output_mask = output_mask.float()
return output_tensor, output_mask
def export() -> None:
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
config_target_size = CONFIG_PARSER.getint('exporting', 'target_size')
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
config_precision = CONFIG_PARSER.get('exporting', 'precision')
os.makedirs(config_directory_path, exist_ok = True)
model = FaceSwapperTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
if config_precision == 'half':
model = HalfPrecision(model).eval()
model.ir_version = torch.tensor(config_ir_version)
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, config_target_size, config_target_size)
torch.onnx.export(model, (source_tensor, target_tensor), config_target_path, input_names = [ 'source', 'target' ], output_names = [ 'output', 'mask' ], opset_version = config_opset_version)
+51
View File
@@ -0,0 +1,51 @@
import torch
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet
WARP_TEMPLATE_SET : WarpTemplateSet =\
{
'arcface_128_v2_to_arcface_112_v2': torch.tensor(
[
[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
]),
'ffhq_to_arcface_128_v2': torch.tensor(
[
[ 8.50048894e-01, -1.29486822e-04, 1.90956388e-03 ],
[ 1.29486822e-04, 8.50048894e-01, 9.56254653e-02 ]
]),
'vgg_face_hq_to_arcface_128_v2': torch.tensor(
[
[ 1.01305414, -0.00140513, -0.00585911 ],
[ 0.00140513, 1.01305414, 0.11169602 ]
])
}
def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:
normed_warp_template = WARP_TEMPLATE_SET.get(warp_template).repeat(input_tensor.shape[0], 1, 1)
affine_grid = nn.functional.affine_grid(normed_warp_template.to(input_tensor.device), list(input_tensor.shape))
output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, align_corners = False, padding_mode = 'reflection')
return output_tensor
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
crop_tensor[:, :, :, :padding[2]] = 0
crop_tensor[:, :, :, 112 - padding[3]:] = 0
embedding = embedder(crop_tensor)
embedding = nn.functional.normalize(embedding, p = 2)
return embedding
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
overlay_tensor[:, 2, :, :] = 1
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask
return output_tensor
+27
View File
@@ -0,0 +1,27 @@
import configparser
import torch
from torchvision import io
from .helper import calc_embedding
from .training import FaceSwapperTrainer
CONFIG_PARSER = configparser.ConfigParser()
CONFIG_PARSER.read('config.ini')
def infer() -> None:
config_generator_path = CONFIG_PARSER.get('inferencing', 'generator_path')
config_embedder_path = CONFIG_PARSER.get('inferencing', 'embedder_path')
config_source_path = CONFIG_PARSER.get('inferencing', 'source_path')
config_target_path = CONFIG_PARSER.get('inferencing', 'target_path')
config_output_path = CONFIG_PARSER.get('inferencing', 'output_path')
generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval()
source_tensor = io.read_image(config_source_path)
target_tensor = io.read_image(config_target_path)
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = generator(source_embedding, target_tensor)
io.write_jpeg(output_tensor, config_output_path)
View File
+35
View File
@@ -0,0 +1,35 @@
from configparser import ConfigParser
from typing import List
from torch import Tensor, nn
from ..networks.nld import NLD
class Discriminator(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
self.config_parser = config_parser
self.discriminators = self.create_discriminators()
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
def create_discriminators(self) -> nn.ModuleList:
discriminators = nn.ModuleList()
for _ in range(self.config_num_discriminators):
discriminator = NLD(self.config_parser).sequences
discriminators.append(discriminator)
return discriminators
def forward(self, input_tensor : Tensor) -> List[Tensor]:
temp_tensor = input_tensor
output_tensors = []
for discriminator in self.discriminators:
output_tensor = discriminator(temp_tensor)
output_tensors.append(output_tensor)
temp_tensor = self.avg_pool(temp_tensor)
return output_tensors
+42
View File
@@ -0,0 +1,42 @@
from configparser import ConfigParser
from typing import Tuple
from torch import Tensor, nn
from ..networks.aad import AAD
from ..networks.masknet import MaskNet
from ..networks.unet import UNet
from ..types import Embedding, Feature, Mask
class Generator(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.encoder = UNet(config_parser)
self.generator = AAD(config_parser)
self.masker = MaskNet(config_parser)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
self.masker.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]:
output_tensor = self.generator(source_embedding, target_features)
target_feature = target_features[-1]
output_mask = self.masker(target_tensor, target_feature)
output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask)
return output_tensor, output_mask
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
return self.encoder(input_tensor)
def init_weight(module : nn.Module) -> None:
if isinstance(module, nn.Linear):
module.weight.data.normal_(std = 0.001)
module.bias.data.zero_()
if isinstance(module, nn.Conv2d):
nn.init.xavier_normal_(module.weight.data)
if isinstance(module, nn.ConvTranspose2d):
nn.init.xavier_normal_(module.weight.data)
+186
View File
@@ -0,0 +1,186 @@
from configparser import ConfigParser
from typing import List, Tuple
import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
class DiscriminatorLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss:
positive_tensors = []
negative_tensors = []
for discriminator_source_tensor in discriminator_source_tensors:
positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ])
positive_tensors.append(positive_tensor)
for discriminator_output_tensor in discriminator_output_tensors:
negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ])
negative_tensors.append(negative_tensor)
positive_loss = torch.stack(positive_tensors).mean()
negative_loss = torch.stack(negative_tensors).mean()
discriminator_loss = (positive_loss + negative_loss) * 0.5
return discriminator_loss
class AdversarialLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight')
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]:
temp_tensors = []
for discriminator_output_tensor in discriminator_output_tensors:
temp_tensor = torch.relu(1 - discriminator_output_tensor).mean(dim = [ 1, 2, 3 ]).mean()
temp_tensors.append(temp_tensor)
adversarial_loss = torch.stack(temp_tensors).mean()
weighted_adversarial_loss = adversarial_loss * self.config_adversarial_weight
return adversarial_loss, weighted_adversarial_loss
class CycleLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_feature, output_feature in zip(target_features, cycle_features):
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
feature_loss = torch.stack(temp_tensors).mean()
reconstruction_loss = self.l1_loss(target_tensor, cycle_tensor)
cycle_loss = (feature_loss + reconstruction_loss) * 0.5
weighted_feature_loss = cycle_loss * self.config_cycle_weight
return cycle_loss, weighted_feature_loss
class FeatureLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight')
def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_feature, output_feature in zip(target_features, output_features):
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
feature_loss = torch.stack(temp_tensors).mean() * 0.5
weighted_feature_loss = feature_loss * self.config_feature_weight
return feature_loss, weighted_feature_loss
class ReconstructionLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
super().__init__()
self.config_reconstruction_weight = config_parser.getfloat('training.losses', 'reconstruction_weight')
self.embedder = embedder
self.mse_loss = nn.MSELoss()
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
with torch.no_grad():
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss
class IdentityLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
super().__init__()
self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
self.embedder = embedder
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
weighted_identity_loss = identity_loss * self.config_identity_weight
return identity_loss, weighted_identity_loss
class GazeLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
super().__init__()
self.config_gaze_weight = config_parser.getfloat('training.losses', 'gaze_weight')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.gazer = gazer
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
output_pitch, output_yaw = self.detect_gaze(output_tensor)
target_pitch, target_yaw = self.detect_gaze(target_tensor)
pitch_loss = self.l1_loss(output_pitch, target_pitch)
yaw_loss = self.l1_loss(output_yaw, target_yaw)
gaze_loss = (pitch_loss + yaw_loss) * 0.5
weighted_gaze_loss = gaze_loss * self.config_gaze_weight
return gaze_loss, weighted_gaze_loss
def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]:
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int()
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
with torch.no_grad():
pitch, yaw = self.gazer(crop_tensor)
return pitch, yaw
class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
super().__init__()
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.face_masker = face_masker
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
target_mask = self.calc_mask(target_tensor)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, output_mask)
weighted_mask_loss = mask_loss * self.config_mask_weight
return mask_loss, weighted_mask_loss
def calc_mask(self, target_tensor : Tensor) -> Tensor:
target_tensor = torch.nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear')
target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5
with torch.no_grad():
output_tensor = self.face_masker(target_tensor)
output_tensor = output_tensor.clamp(0, 1)
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
return output_tensor
+191
View File
@@ -0,0 +1,191 @@
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from ..types import Embedding, Feature
class AAD(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels)
self.layers = self.create_layers()
def create_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
if self.config_output_size == 128:
layers.extend(
[
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 256:
layers.extend(
[
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 512:
layers.extend(
[
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 1024:
layers.extend(
[
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
layers.extend(
[
AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(256, 128, 128, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(128, 64, 64, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(64, 3, 64, self.config_source_channels, self.config_num_blocks)
])
return layers
def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor:
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
for index, layer in enumerate(self.layers[:-1]):
target_feature = target_features[index]
temp_tensor = layer(temp_tensors, source_embedding, target_feature)
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
target_feature = target_features[-1]
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_feature)
output_tensor = torch.tanh(temp_tensors)
return output_tensor
class AdaptiveFeatureModulation(nn.Module):
def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None:
super().__init__()
self.context_input_channels = input_channels
self.context_output_channels = output_channels
self.context_target_channels = target_channels
self.context_source_channels = source_channels
self.context_num_blocks = num_blocks
self.primary_layers = self.create_primary_layers()
self.shortcut_layers = self.create_shortcut_layers()
def create_primary_layers(self) -> nn.ModuleList:
primary_layers = nn.ModuleList()
for index in range(self.context_num_blocks):
primary_layers.extend(
[
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
nn.ReLU(inplace = True)
])
if index < self.context_num_blocks - 1:
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_input_channels, kernel_size = 3, padding = 1, bias = False))
else:
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False))
return primary_layers
def create_shortcut_layers(self) -> nn.ModuleList:
shortcut_layers = nn.ModuleList()
if self.context_input_channels > self.context_output_channels:
shortcut_layers.extend(
[
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
nn.ReLU(inplace = True),
nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
])
return shortcut_layers
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
primary_tensor = input_tensor
for primary_layer in self.primary_layers:
if isinstance(primary_layer, FeatureModulation):
primary_tensor = primary_layer(primary_tensor, source_embedding, target_feature)
else:
primary_tensor = primary_layer(primary_tensor)
if self.context_input_channels > self.context_output_channels:
shortcut_tensor = input_tensor
for shortcut_layer in self.shortcut_layers:
if isinstance(shortcut_layer, FeatureModulation):
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_feature)
else:
shortcut_tensor = shortcut_layer(shortcut_tensor)
input_tensor = shortcut_tensor
return primary_tensor + input_tensor
class FeatureModulation(nn.Module):
def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None:
super().__init__()
self.context_input_channels = input_channels
self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
self.linear1 = nn.Linear(source_channels, input_channels)
self.linear2 = nn.Linear(source_channels, input_channels)
self.instance_norm = nn.InstanceNorm2d(input_channels)
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
temp_tensor = self.instance_norm(input_tensor)
source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
source_modulation = source_scale * temp_tensor + source_shift
target_scale = self.conv1(target_feature)
target_shift = self.conv2(target_feature)
target_modulation = target_scale * temp_tensor + target_shift
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
output_tensor = (1 - temp_mask) * target_modulation + temp_mask * source_modulation
return output_tensor
class PixelShuffleUpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1),
nn.PixelShuffle(upscale_factor = 2)
)
def forward(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.view(input_tensor.shape[0], -1, 1, 1)
output_tensor = self.sequences(temp_tensor)
return output_tensor
+111
View File
@@ -0,0 +1,111 @@
from configparser import ConfigParser
import torch
from torch import Tensor, nn
from ..types import Feature, Mask
class MaskNet(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
self.up_samples = self.create_up_samples(self.config_num_filters)
self.bottleneck = BottleNeck(self.config_num_filters * 4)
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
self.sigmoid = nn.Sigmoid()
@staticmethod
def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList:
return nn.ModuleList(
[
DownSample(input_channels, num_filters),
DownSample(num_filters, num_filters * 2),
DownSample(num_filters * 2, num_filters * 4)
])
@staticmethod
def create_up_samples(num_filters : int) -> nn.ModuleList:
return nn.ModuleList(
[
UpSample(num_filters * 4, num_filters * 2),
UpSample(num_filters * 2, num_filters),
UpSample(num_filters, num_filters)
])
def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask:
output_mask = torch.cat([ input_tensor, input_feature ], dim = 1)
for down_sample in self.down_samples:
output_mask = down_sample(output_mask)
output_mask = self.bottleneck(output_mask)
for up_sample in self.up_samples:
output_mask = up_sample(output_mask)
output_mask = self.conv(output_mask)
output_mask = self.sigmoid(output_mask)
return output_mask
class BottleNeck(nn.Module):
def __init__(self, num_filters : int):
super().__init__()
self.sequences = self.create_sequences(num_filters)
self.relu = nn.ReLU(inplace = True)
@staticmethod
def create_sequences(num_filters : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU(inplace = True),
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU(inplace = True)
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor) + input_tensor
output_tensor = self.relu(output_tensor)
return output_tensor
class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2),
nn.ReLU(inplace = True)
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
return output_tensor
class DownSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace = True),
nn.MaxPool2d(2)
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
return output_tensor
+48
View File
@@ -0,0 +1,48 @@
import math
from configparser import ConfigParser
from torch import Tensor, nn
class NLD(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_input_channels = config_parser.getint('training.model.discriminator', 'input_channels')
self.config_num_filters = config_parser.getint('training.model.discriminator', 'num_filters')
self.config_kernel_size = config_parser.getint('training.model.discriminator', 'kernel_size')
self.config_num_layers = config_parser.getint('training.model.discriminator', 'num_layers')
self.layers = self.create_layers()
self.sequences = nn.Sequential(*self.layers)
def create_layers(self) -> nn.ModuleList:
padding = math.ceil((self.config_kernel_size - 1) / 2)
current_filters = self.config_num_filters
layers = nn.ModuleList(
[
nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.LeakyReLU(0.2, True)
])
for _ in range(1, self.config_num_layers):
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True)
]
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True),
nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding)
]
return layers
def forward(self, input_tensor : Tensor) -> Tensor:
return self.sequences(input_tensor)
+157
View File
@@ -0,0 +1,157 @@
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from ..types import Feature
class UNet(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.down_samples = self.create_down_samples()
self.up_samples = self.create_up_samples()
def create_down_samples(self) -> nn.ModuleList:
down_samples = nn.ModuleList(
[
DownSample(3, 32),
DownSample(32, 64),
DownSample(64, 128),
DownSample(128, 256),
DownSample(256, 512)
])
if self.config_output_size == 128:
down_samples.extend(
[
DownSample(512, 512)
])
if self.config_output_size == 256:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 1024)
])
if self.config_output_size == 512:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 2048),
DownSample(2048, 2048)
])
if self.config_output_size == 1024:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 2048),
DownSample(2048, 4096),
DownSample(4096, 4096)
])
return down_samples
def create_up_samples(self) -> nn.ModuleList:
up_samples = nn.ModuleList()
if self.config_output_size == 128:
up_samples.extend(
[
UpSample(512, 512)
])
if self.config_output_size == 256:
up_samples.extend(
[
UpSample(1024, 1024),
UpSample(2048, 512)
])
if self.config_output_size == 512:
up_samples.extend(
[
UpSample(2048, 2048),
UpSample(4096, 1024),
UpSample(2048, 512)
])
if self.config_output_size == 1024:
up_samples.extend(
[
UpSample(4096, 4096),
UpSample(8192, 2048),
UpSample(4096, 1024),
UpSample(2048, 512)
])
up_samples.extend(
[
UpSample(1024, 256),
UpSample(512, 128),
UpSample(256, 64),
UpSample(128, 32)
])
return up_samples
def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]:
down_features = []
up_features = []
temp_feature = target_tensor
for down_sample in self.down_samples:
temp_feature = down_sample(temp_feature)
down_features.append(temp_feature)
bottleneck_feature = down_features[-1]
temp_feature = bottleneck_feature
for index, up_sample in enumerate(self.up_samples):
skip_tensor = down_features[-(index + 2)]
temp_feature = up_sample(temp_feature, skip_tensor)
up_features.append(temp_feature)
final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False)
return bottleneck_feature, *up_features, final_feature
class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.1, inplace = True)
)
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1)
return output_tensor
class DownSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.1, inplace = True)
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
return output_tensor
+248
View File
@@ -0,0 +1,248 @@
import os
import warnings
from configparser import ConfigParser
from typing import List, Tuple
import torch
import torchvision
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import ConcatDataset, Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import DynamicDataset
from .helper import calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
from .types import Batch, Embedding, Mask, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class FaceSwapperTrainer(LightningModule):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path')
self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval()
self.generator = Generator(config_parser)
self.discriminator = Discriminator(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.cycle_loss = CycleLoss(config_parser)
self.feature_loss = FeatureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
self.gaze_loss = GazeLoss(config_parser, self.gazer)
self.mask_loss = MaskLoss(config_parser, self.face_masker)
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
with torch.no_grad():
generator_target_features = self.generator.encode_features(target_tensor)
output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
return output_tensor, output_mask
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
generator_config =\
{
'optimizer': generator_optimizer,
'lr_scheduler':
{
'scheduler': generator_scheduler,
'interval': 'step'
}
}
discriminator_config =\
{
'optimizer': discriminator_optimizer,
'lr_scheduler':
{
'scheduler': discriminator_scheduler,
'interval': 'step'
}
}
return generator_config, discriminator_config
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
generator_output_features = self.generator.encode_features(generator_output_tensor)
cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
cycle_output_features = self.generator.encode_features(cycle_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
if do_update:
generator_optimizer.step()
generator_optimizer.zero_grad()
self.untoggle_optimizer(generator_optimizer)
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % self.config_preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
self.log('adversarial_loss', adversarial_loss)
self.log('cycle_loss', cycle_loss)
self.log('feature_loss', feature_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)
self.log('gaze_loss', gaze_loss)
self.log('mask_loss', mask_loss)
return generator_loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = self.forward(source_embedding, target_tensor)
output_embedding = calc_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
return validation_score
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None:
preview_limit = 8
preview_cells = []
overlay_tensor = overlay_mask(output_tensor, output_mask)
for source_tensor, target_tensor, output_tensor, overlay_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]):
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
preview_cells.append(preview_cell)
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True)
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
dataset_size = len(dataset) # type:ignore[arg-type]
training_size = int(dataset_size * config_split_ratio)
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset
def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
datasets = []
for config_section in config_parser.sections():
if config_section.startswith('training.dataset'):
current_config_parser = ConfigParser()
current_config_parser.add_section('training.dataset')
for key, value in config_parser.items(config_section):
current_config_parser.set('training.dataset', key, value)
datasets.append(DynamicDataset(current_config_parser))
return datasets
def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
logger = TensorBoardLogger(config_logger_path, config_logger_name)
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = config_max_epochs,
strategy = config_strategy,
precision = config_precision,
callbacks =
[
ModelCheckpoint(
monitor = 'generator_loss',
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_train_steps = 1000,
save_top_k = 3,
save_last = True
)
],
val_check_interval = 1000
)
def train() -> None:
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER))
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER)
trainer = create_trainer()
if os.path.isfile(config_resume_path):
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
else:
trainer.fit(face_swapper_trainer, training_loader, validation_loader)
+24
View File
@@ -0,0 +1,24 @@
from typing import Any, Dict, Literal, Tuple, TypeAlias
from torch import Tensor
from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor]
BatchMode = Literal['equal', 'same', 'different']
Feature : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
Mask : TypeAlias = Tensor
Loss : TypeAlias = Tensor
Padding : TypeAlias = Tuple[int, int, int, int]
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module
GazerModule : TypeAlias = Module
FaceMaskerModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any
WarpTemplate = Literal['arcface_128_v2_to_arcface_112_v2', 'ffhq_to_arcface_128_v2', 'vgg_face_hq_to_arcface_128_v2']
WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]
+57
View File
@@ -0,0 +1,57 @@
from configparser import ConfigParser
import pytest
import torch
from face_swapper.src.networks.aad import AAD
from face_swapper.src.networks.masknet import MaskNet
from face_swapper.src.networks.unet import UNet
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
def test_aad_with_unet(output_size : int) -> None:
config_parser = ConfigParser()
config_parser.read_dict(
{
'training.model.generator':
{
'source_channels': '512',
'output_channels': str(output_size * 16),
'output_size': str(output_size),
'num_blocks': '2'
}
})
encoder = UNet(config_parser).eval()
generator = AAD(config_parser).eval()
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, output_size, output_size)
target_features = encoder(target_tensor)
output_tensor = generator(source_tensor, target_features)
assert output_tensor.shape == (1, 3, output_size, output_size)
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
def test_mask_net(output_size : int) -> None:
config_parser = ConfigParser()
config_parser.read_dict(
{
'training.model.masker':
{
'input_channels': '67',
'output_channels': '1',
'num_filters': '16'
}
})
masker = MaskNet(config_parser).eval()
target_tensor = torch.randn(1, 3, output_size, output_size)
target_feature = torch.randn(1, 64, output_size, output_size)
output_mask = masker(target_tensor, target_feature)
assert output_mask.shape == (1, 1, output_size, output_size)
+6
View File
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
from src.training import train
if __name__ == '__main__':
train()
+1
View File
@@ -5,3 +5,4 @@ disallow_untyped_calls = True
disallow_untyped_defs = True
ignore_missing_imports = True
strict_optional = False
explicit_package_bases = True
+9 -5
View File
@@ -1,6 +1,10 @@
lightning==2.4.0
numpy==1.26.4
--extra-index-url https://download.pytorch.org/whl/cu124
albumentations==2.0.5
lightning==2.5.1
onnx==1.17.0
onnxruntime==1.20.0
opencv-python==4.10.0.84
mxnet==1.9.1
onnxruntime==1.21.0
pytorch-msssim==1.0.0
torch==2.6.0
torchdata==0.11.0
torchvision==0.21.0
tensorboard==2.19.0