diff --git a/.flake8 b/.flake8 index a840286..544f9ff 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,5 @@ [flake8] -select = E3, E4, F, I1, I2 +select = E22, E23, E24, E27, E3, E4, E7, F, I1, I2 plugins = flake8-import-order -application_import_names = arcface_converter +application_import_names = embedding_converter, face_swapper import-order-style = pycharm -per-file-ignores = preparing.py:E402 - diff --git a/.github/preview_arcface_converter.png b/.github/previews/embedding_converter.png similarity index 100% rename from .github/preview_arcface_converter.png rename to .github/previews/embedding_converter.png diff --git a/.github/previews/face_swapper.png b/.github/previews/face_swapper.png new file mode 100644 index 0000000..de5022c Binary files /dev/null and b/.github/previews/face_swapper.png differ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b73fedb..c2e3cd2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,12 +8,24 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - name: Set up Python 3.10 + - name: Set up Python 3.12 uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.12' - run: pip install flake8 - run: pip install flake8-import-order - run: pip install mypy - - run: flake8 arcface_converter - - run: mypy arcface_converter + - run: flake8 embedding_converter face_swapper + - run: mypy embedding_converter face_swapper + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' + - run: pip install torch torchvision + - run: pip install pytest + - run: PYTHONPATH=/home/runner/work/facefusion-labs/facefusion-labs pytest diff --git a/.gitignore b/.gitignore index 8ee9a7e..454001b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,10 @@ __pycache__ +.assets +.datasets .idea +.inputs +.exports +.logs +.models +.outputs .vscode diff --git a/arcface_converter/LICENSE.md b/arcface_converter/LICENSE.md deleted file mode 100644 index aae2360..0000000 --- a/arcface_converter/LICENSE.md +++ /dev/null @@ -1,3 +0,0 @@ -MIT license - -Copyright (c) 2024 Henry Ruhs diff --git a/arcface_converter/README.md b/arcface_converter/README.md deleted file mode 100644 index dbc3e1a..0000000 --- a/arcface_converter/README.md +++ /dev/null @@ -1,93 +0,0 @@ -ArcFace Converter -================= - -> Convert face embeddings between various ArcFace models. - -![License](https://img.shields.io/badge/license-MIT-green) - - -Preview -------- - -![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/master/.github/preview_arcface_converter.png?sanitize=true) - - -Installation ------------- - -``` -pip install -r requirements.txt -``` - - -Example -------- - -This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap. - -``` -[preparing.dataset] -dataset_path = datasets/megaface/train.rec -crop_size = 112 -process_limit = 650000 - -[preparing.model] -source_path = models/arcface_w600k_r50.onnx -target_path = models/arcface_simswap.onnx - -[preparing.input] -directory_path = inputs -source_path = inputs/arcface_w600k_r50.npy -target_path = inputs/arcface_simswap.npy - -[training.loader] -split_ratio = 0.8 -batch_size = 51200 -num_workers = 8 - -[training.trainer] -max_epochs = 4096 - -[training.output] -directory_path = outputs -file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f} - -[exporting] -directory_path = exports -source_path = outputs/last.ckpt -target_path = exports/arcface_converter_simswap.onnx -opset_version = 15 - -[execution] -providers = CUDAExecutionProvider -``` - - -Preparing ---------- - -Prepare the face embedding pairs. - -``` -python prepare.py -``` - - -Training --------- - -Train the ArcFace converter model. - -``` -python train.py -``` - - -Exporting ---------- - -Export the model to ONNX. - -``` -python export.py -``` diff --git a/arcface_converter/src/exporting.py b/arcface_converter/src/exporting.py deleted file mode 100644 index c5d9693..0000000 --- a/arcface_converter/src/exporting.py +++ /dev/null @@ -1,22 +0,0 @@ -import configparser -from os import makedirs - -import torch - -from .training import ArcFaceConverterTrainer - -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') - - -def export() -> None: - directory_path = CONFIG.get('exporting', 'directory_path') - source_path = CONFIG.get('exporting', 'source_path') - target_path = CONFIG.get('exporting', 'target_path') - opset_version = CONFIG.getint('exporting', 'opset_version') - - makedirs(directory_path, exist_ok = True) - model = ArcFaceConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu') - model.eval() - input_tensor = torch.randn(1, 512) - torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version) diff --git a/arcface_converter/src/model.py b/arcface_converter/src/model.py deleted file mode 100644 index cc6ecd9..0000000 --- a/arcface_converter/src/model.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -import torch.nn as nn -from torch import Tensor - - -class ArcFaceConverter(nn.Module): - def __init__(self) -> None: - super(ArcFaceConverter, self).__init__() - self.fc1 = nn.Linear(512, 1024) - self.fc2 = nn.Linear(1024, 2048) - self.fc3 = nn.Linear(2048, 1024) - self.fc4 = nn.Linear(1024, 512) - self.activation = nn.LeakyReLU() - - def forward(self, inputs : Tensor) -> Tensor: - norm_inputs = inputs / torch.norm(inputs) - outputs = self.activation(self.fc1(norm_inputs)) - outputs = self.activation(self.fc2(outputs)) - outputs = self.activation(self.fc3(outputs)) - outputs = self.fc4(outputs) - return outputs diff --git a/arcface_converter/src/preparing.py b/arcface_converter/src/preparing.py deleted file mode 100644 index 60151d0..0000000 --- a/arcface_converter/src/preparing.py +++ /dev/null @@ -1,79 +0,0 @@ -import configparser -from os import makedirs -from os.path import isfile -from typing import List - -import numpy -numpy.bool = numpy.bool_ -from mxnet.io import ImageRecordIter -from onnxruntime import InferenceSession -from tqdm import tqdm - -from .typing import Embedding, EmbeddingPairs, VisionFrame - -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') - - -def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame: - crop_vision_frame = crop_vision_frame.astype(numpy.float32) / 255 - crop_vision_frame = (crop_vision_frame - 0.5) * 2 - return crop_vision_frame - - -def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession: - inference_session = InferenceSession(model_path, providers = execution_providers) - return inference_session - - -def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding: - embedding = inference_session.run(None, - { - 'input': crop_vision_frame - })[0] - - return embedding - - -def process_embeddings(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingPairs: - dataset_process_limit = CONFIG.getint('preparing.dataset', 'process_limit') - embedding_pairs = [] - - with tqdm(total = dataset_process_limit) as progress: - for batch in dataset_reader: - crop_vision_frame = batch.data[0].asnumpy() - crop_vision_frame = prepare_crop_vision_frame(crop_vision_frame) - source_embedding = forward(source_inference_session, crop_vision_frame) - target_embedding = forward(target_inference_session, crop_vision_frame) - embedding_pairs.append([ source_embedding, target_embedding ]) - progress.update() - - if progress.n == dataset_process_limit: - return numpy.concatenate(embedding_pairs, axis = 1).T - - return numpy.concatenate(embedding_pairs, axis = 1).T - - -def prepare() -> None: - dataset_path = CONFIG.get('preparing.dataset', 'dataset_path') - dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size') - model_source_path = CONFIG.get('preparing.model', 'source_path') - model_target_path = CONFIG.get('preparing.model', 'target_path') - input_directory_path = CONFIG.get('preparing.input', 'directory_path') - input_source_path = CONFIG.get('preparing.input', 'source_path') - input_target_path = CONFIG.get('preparing.input', 'target_path') - execution_providers = CONFIG.get('execution', 'providers').split(' ') - - makedirs(input_directory_path, exist_ok = True) - if isfile(dataset_path) and isfile(model_source_path) and isfile(model_target_path): - dataset_reader = ImageRecordIter( - path_imgrec = dataset_path, - data_shape = (3, dataset_crop_size, dataset_crop_size), - batch_size = 1, - shuffle = False - ) - source_inference_session = create_inference_session(model_source_path, execution_providers) - target_inference_session = create_inference_session(model_target_path, execution_providers) - embedding_pairs = process_embeddings(dataset_reader, source_inference_session, target_inference_session) - numpy.save(input_source_path, embedding_pairs[..., 0].T) - numpy.save(input_target_path, embedding_pairs[..., 1].T) diff --git a/arcface_converter/src/training.py b/arcface_converter/src/training.py deleted file mode 100644 index 149bf43..0000000 --- a/arcface_converter/src/training.py +++ /dev/null @@ -1,116 +0,0 @@ -import configparser -from typing import Any, Tuple - -import numpy -import pytorch_lightning -import torch -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.tuner.tuning import Tuner -from torch import Tensor -from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split - -from .model import ArcFaceConverter -from .typing import Batch, Loader - -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') - - -class ArcFaceConverterTrainer(pytorch_lightning.LightningModule): - def __init__(self) -> None: - super(ArcFaceConverterTrainer, self).__init__() - self.model = ArcFaceConverter() - self.loss_fn = torch.nn.MSELoss() - self.lr = 0.001 - - def forward(self, source_embedding : Tensor) -> Tensor: - return self.model(source_embedding) - - def training_step(self, batch : Batch, batch_index : int) -> Tensor: - source_embedding, target_embedding = batch - output_embedding = self(source_embedding) - loss = self.loss_fn(output_embedding, target_embedding) - self.log('train_loss', loss, prog_bar = True, logger = True) - return loss - - def validation_step(self, batch : Batch, batch_index : int) -> Tensor: - source_embedding, target_embedding = batch - output_embedding = self(source_embedding) - loss = self.loss_fn(output_embedding, target_embedding) - self.log('val_loss', loss, prog_bar = True, logger = True) - return loss - - def configure_optimizers(self) -> Any: - optimizer = torch.optim.Adam(self.parameters(), lr = self.lr) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) - - return\ - { - 'optimizer': optimizer, - 'lr_scheduler': - { - 'scheduler': scheduler, - 'monitor': 'train_loss', - 'interval': 'epoch', - 'frequency': 1 - } - } - - -def create_loaders() -> Tuple[Loader, Loader]: - loader_batch_size = CONFIG.getint('training.loader', 'batch_size') - loader_num_workers = CONFIG.getint('training.loader', 'num_workers') - - training_dataset, validate_dataset = split_dataset() - training_loader = DataLoader(training_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = True, pin_memory = True) - validation_loader = DataLoader(validate_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = False, pin_memory = True) - return training_loader, validation_loader - - -def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]: - input_source_path = CONFIG.get('preparing.input', 'source_path') - input_target_path = CONFIG.get('preparing.input', 'target_path') - loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio') - - source_input = torch.from_numpy(numpy.load(input_source_path)).float() - target_input = torch.from_numpy(numpy.load(input_target_path)).float() - dataset = TensorDataset(source_input, target_input) - - dataset_size = len(dataset) - training_size = int(loader_split_ratio * len(dataset)) - validation_size = int(dataset_size - training_size) - training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ]) - return training_dataset, validate_dataset - - -def create_trainer() -> Trainer: - trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs') - output_directory_path = CONFIG.get('training.output', 'directory_path') - output_file_pattern = CONFIG.get('training.output', 'file_pattern') - - return Trainer( - max_epochs = trainer_max_epochs, - callbacks = - [ - ModelCheckpoint( - monitor = 'train_loss', - dirpath = output_directory_path, - filename = output_file_pattern, - every_n_epochs = 10, - save_top_k = 3, - save_last = True - ) - ], - enable_progress_bar = True, - log_every_n_steps = 2 - ) - - -def train() -> None: - trainer = create_trainer() - training_loader, validation_loader = create_loaders() - model = ArcFaceConverterTrainer() - tuner = Tuner(trainer) - tuner.lr_find(model, training_loader, validation_loader) - trainer.fit(model, training_loader, validation_loader) diff --git a/arcface_converter/src/typing.py b/arcface_converter/src/typing.py deleted file mode 100644 index faeeec2..0000000 --- a/arcface_converter/src/typing.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Any, Tuple - -from numpy.typing import NDArray -from torch import Tensor -from torch.utils.data import DataLoader - -Batch = Tuple[Tensor, Tensor] -Loader = DataLoader[Tuple[Tensor, ...]] - -Embedding = NDArray[Any] -EmbeddingPairs = NDArray[Any] -FaceLandmark5 = NDArray[Any] -VisionFrame = NDArray[Any] diff --git a/embedding_converter/LICENSE.md b/embedding_converter/LICENSE.md new file mode 100644 index 0000000..b31052b --- /dev/null +++ b/embedding_converter/LICENSE.md @@ -0,0 +1,3 @@ +OpenRAIL-MS license + +Copyright (c) 2025 Henry Ruhs diff --git a/embedding_converter/README.md b/embedding_converter/README.md new file mode 100644 index 0000000..3f940c1 --- /dev/null +++ b/embedding_converter/README.md @@ -0,0 +1,96 @@ +Embedding Converter +=================== + +> Convert face embeddings between various models. + +![License](https://img.shields.io/badge/license-OpenRAIL--MS-green) + + +Preview +------- + +![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/next/.github/previews/embedding_converter.png?sanitize=true) + + +Installation +------------ + +``` +pip install -r requirements.txt +``` + + +Setup +----- + +This `config.ini` utilizes the MegaFace dataset to train the Embedding Converter for SimSwap. + +``` +[training.dataset] +file_pattern = .datasets/megaface/**/*.jpg +``` + +``` +[training.loader] +batch_size = 256 +num_workers = 8 +split_ratio = 0.95 +``` + +``` +[training.model] +source_path = .models/arcface_w600k_r50.pt +target_path = .models/arcface_simswap.pt +``` + +``` +[training.trainer] +learning_rate = 0.001 +max_epochs = 4096 +strategy = auto +precision = 16-mixed +logger_path = .logs +logger_name = arcface_converter_simswap +``` + +``` +[training.output] +directory_path = .outputs +file_pattern = arcface_converter_simswap_{epoch}_{step} +resume_path = .outputs/last.ckpt +``` + +``` +[exporting] +directory_path = .exports +source_path = .outputs/last.ckpt +target_path = .exports/arcface_converter_simswap.onnx +ir_version = 10 +opset_version = 15 +``` + + +Training +-------- + +Train the Embedding Converter model. + +``` +python train.py +``` + +Launch the TensorBoard to monitor the training. + +``` +tensorboard --logdir=.logs +``` + + +Exporting +--------- + +Export the model to ONNX. + +``` +python export.py +``` diff --git a/arcface_converter/__init__.py b/embedding_converter/__init__.py similarity index 100% rename from arcface_converter/__init__.py rename to embedding_converter/__init__.py diff --git a/arcface_converter/config.ini b/embedding_converter/config.ini similarity index 59% rename from arcface_converter/config.ini rename to embedding_converter/config.ini index 70b8d14..a040687 100644 --- a/arcface_converter/config.ini +++ b/embedding_converter/config.ini @@ -1,34 +1,31 @@ -[preparing.dataset] -dataset_path = -crop_size = -process_limit = - -[preparing.model] -source_path = -target_path = - -[preparing.input] -directory_path = -source_path = -target_path = +[training.dataset] +file_pattern = [training.loader] -split_ratio = batch_size = num_workers = +split_ratio = + +[training.model] +source_path = +target_path = [training.trainer] +learning_rate = max_epochs = +strategy = +precision = +logger_path = +logger_name = [training.output] directory_path = file_pattern = +resume_path = [exporting] directory_path = source_path = target_path = +ir_version = opset_version = - -[execution] -providers = diff --git a/arcface_converter/export.py b/embedding_converter/export.py similarity index 100% rename from arcface_converter/export.py rename to embedding_converter/export.py diff --git a/arcface_converter/src/__init__.py b/embedding_converter/src/__init__.py similarity index 100% rename from arcface_converter/src/__init__.py rename to embedding_converter/src/__init__.py diff --git a/embedding_converter/src/dataset.py b/embedding_converter/src/dataset.py new file mode 100644 index 0000000..c3b5503 --- /dev/null +++ b/embedding_converter/src/dataset.py @@ -0,0 +1,34 @@ +import glob +from configparser import ConfigParser + +from torch import Tensor +from torch.utils.data import Dataset +from torchvision import io, transforms + +from .types import Batch + + +class StaticDataset(Dataset[Tensor]): + def __init__(self, config_parser : ConfigParser) -> None: + self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern') + self.file_paths = glob.glob(self.config_file_pattern) + self.transforms = self.compose_transforms() + + def __getitem__(self, index : int) -> Batch: + file_path = self.file_paths[index] + temp_tensor = io.read_image(file_path) + return self.transforms(temp_tensor) + + def __len__(self) -> int: + return len(self.file_paths) + + @staticmethod + def compose_transforms() -> transforms: + return transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC), + transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) diff --git a/embedding_converter/src/exporting.py b/embedding_converter/src/exporting.py new file mode 100644 index 0000000..f7e7549 --- /dev/null +++ b/embedding_converter/src/exporting.py @@ -0,0 +1,23 @@ +import os +from configparser import ConfigParser + +import torch + +from .training import EmbeddingConverterTrainer + +CONFIG_PARSER = ConfigParser() +CONFIG_PARSER.read('config.ini') + + +def export() -> None: + config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path') + config_source_path = CONFIG_PARSER.get('exporting', 'source_path') + config_target_path = CONFIG_PARSER.get('exporting', 'target_path') + config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version') + config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version') + + os.makedirs(config_directory_path, exist_ok = True) + model = EmbeddingConverterTrainer.load_from_checkpoint(config_source_path, map_location = 'cpu').eval() + model.ir_version = torch.tensor(config_ir_version) + input_tensor = torch.randn(1, 512) + torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version) diff --git a/embedding_converter/src/models/__init__.py b/embedding_converter/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/embedding_converter/src/models/embedding_converter.py b/embedding_converter/src/models/embedding_converter.py new file mode 100644 index 0000000..bc599c3 --- /dev/null +++ b/embedding_converter/src/models/embedding_converter.py @@ -0,0 +1,28 @@ +import torch +from torch import Tensor, nn + + +class EmbeddingConverter(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = self.create_layers() + self.leaky_relu = nn.LeakyReLU() + + @staticmethod + def create_layers() -> nn.ModuleList: + return nn.ModuleList( + [ + nn.Linear(512, 1024), + nn.Linear(1024, 2048), + nn.Linear(2048, 1024), + nn.Linear(1024, 512) + ]) + + def forward(self, input_tensor : Tensor) -> Tensor: + output_tensor = input_tensor / torch.norm(input_tensor) + + for layer in self.layers[:-1]: + output_tensor = self.leaky_relu(layer(output_tensor)) + + output_tensor = self.layers[-1](output_tensor) + return output_tensor diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py new file mode 100644 index 0000000..fbd5e39 --- /dev/null +++ b/embedding_converter/src/training.py @@ -0,0 +1,134 @@ +import os +from configparser import ConfigParser +from typing import Tuple + +import torch +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger +from torch import Tensor, nn +from torch.utils.data import Dataset, random_split +from torchdata.stateful_dataloader import StatefulDataLoader + +from .dataset import StaticDataset +from .models.embedding_converter import EmbeddingConverter +from .types import Batch, Embedding, OptimizerSet + +CONFIG_PARSER = ConfigParser() +CONFIG_PARSER.read('config.ini') + + +class EmbeddingConverterTrainer(LightningModule): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_source_path = config_parser.get('training.model', 'source_path') + self.config_target_path = config_parser.get('training.model', 'target_path') + self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.embedding_converter = EmbeddingConverter() + self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval() + self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval() + self.mse_loss = nn.MSELoss() + + def forward(self, source_embedding : Embedding) -> Embedding: + return self.embedding_converter(source_embedding) + + def training_step(self, batch : Batch, batch_index : int) -> Tensor: + with torch.no_grad(): + source_embedding = self.source_embedder(batch) + target_embedding = self.target_embedder(batch) + output_embedding = self(source_embedding) + training_loss = self.mse_loss(output_embedding, target_embedding) + self.log('training_loss', training_loss, prog_bar = True) + return training_loss + + def validation_step(self, batch : Batch, batch_index : int) -> Tensor: + with torch.no_grad(): + source_embedding = self.source_embedder(batch) + output_embedding = self(source_embedding) + validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 + self.log('validation_score', validation_score, sync_dist = True, prog_bar = True) + return validation_score + + def configure_optimizers(self) -> OptimizerSet: + optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + optimizer_set =\ + { + 'optimizer': optimizer, + 'lr_scheduler': + { + 'scheduler': scheduler, + 'monitor': 'training_loss', + 'interval': 'epoch', + 'frequency': 1 + } + } + + return optimizer_set + + +def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: + config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size') + config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers') + + training_dataset, validate_dataset = split_dataset(dataset) + training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True) + return training_loader, validation_loader + + +def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]: + config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio') + + dataset_size = len(dataset) # type:ignore[arg-type] + training_size = int(dataset_size * config_split_ratio) + validation_size = int(dataset_size - training_size) + training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ]) + return training_dataset, validate_dataset + + +def create_trainer() -> Trainer: + config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') + config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') + config_precision = CONFIG_PARSER.get('training.trainer', 'precision') + config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') + config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') + logger = TensorBoardLogger(config_logger_path, config_logger_name) + + return Trainer( + logger = logger, + log_every_n_steps = 10, + max_epochs = config_max_epochs, + strategy = config_strategy, + precision = config_precision, + callbacks = + [ + ModelCheckpoint( + monitor = 'training_loss', + dirpath = config_directory_path, + filename = config_file_pattern, + every_n_epochs = 1, + save_top_k = 3, + save_last = True + ) + ] + ) + + +def train() -> None: + config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path') + + if torch.cuda.is_available(): + torch.set_float32_matmul_precision('high') + + dataset = StaticDataset(CONFIG_PARSER) + training_loader, validation_loader = create_loaders(dataset) + embedding_converter_trainer = EmbeddingConverterTrainer(CONFIG_PARSER) + trainer = create_trainer() + + if os.path.exists(config_resume_path): + trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_resume_path) + else: + trainer.fit(embedding_converter_trainer, training_loader, validation_loader) diff --git a/embedding_converter/src/types.py b/embedding_converter/src/types.py new file mode 100644 index 0000000..0522b39 --- /dev/null +++ b/embedding_converter/src/types.py @@ -0,0 +1,8 @@ +from typing import Any, TypeAlias + +from torch import Tensor + +Batch : TypeAlias = Tensor +Embedding : TypeAlias = Tensor + +OptimizerSet : TypeAlias = Any diff --git a/arcface_converter/train.py b/embedding_converter/train.py similarity index 100% rename from arcface_converter/train.py rename to embedding_converter/train.py diff --git a/face_swapper/LICENSE.md b/face_swapper/LICENSE.md new file mode 100644 index 0000000..011a811 --- /dev/null +++ b/face_swapper/LICENSE.md @@ -0,0 +1,3 @@ +ResearchRAIL-MS license + +Copyright (c) 2025 Henry Ruhs diff --git a/face_swapper/README.md b/face_swapper/README.md new file mode 100644 index 0000000..a6c65a4 --- /dev/null +++ b/face_swapper/README.md @@ -0,0 +1,160 @@ +Face Swapper +============ + +> Face shape and occlusion aware identity transfer. + +![License](https://img.shields.io/badge/license-ResearchRAIL--MS-red) + + +Preview +------- + +![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/next/.github/previews/face_swapper.png?sanitize=true) + + +Installation +------------ + +``` +pip install -r requirements.txt +``` + + +Setup +----- + +This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model. + +``` +[training.dataset] +file_pattern = .datasets/vggface2/**/*.jpg +warp_template = vgg_face_hq_to_arcface_128_v2 +transform_size = 256 +batch_mode = equal +batch_ratio = 0.2 +``` + +``` +[training.loader] +batch_size = 8 +num_workers = 8 +split_ratio = 0.9995 +``` + +``` +[training.model] +generator_embedder_path = .models/blendface.pt +loss_embedder_path = .models/adaface.pt +gazer_path = .models/gazer.pt +face_masker_path = .models/face_masker.pt +``` + +``` +[training.model.generator] +source_channels = 512 +output_channels = 4096 +output_size = 256 +num_blocks = 2 +``` + +``` +[training.model.discriminator] +input_channels = 3 +num_filters = 64 +num_layers = 5 +num_discriminators = 3 +kernel_size = 4 +``` + +``` +[training.model.masker] +input_channels = 67 +output_channels = 1 +num_filters = 16 +``` + +``` +[training.losses] +adversarial_weight = 1.0 +cycle_weight = 1.0 +feature_weight = 10.0 +reconstruction_weight = 10.0 +identity_weight = 20.0 +gaze_weight = 0.05 +mask_weight = 5.0 +``` + +``` +[training.trainer] +accumulate_size = 4 +learning_rate = 0.0004 +max_epochs = 50 +strategy = auto +precision = 16-mixed +logger_path = .logs +logger_name = face_swapper +preview_frequency = 100 +``` + +``` +[training.output] +directory_path = .outputs +file_pattern = face_swapper_{epoch}_{step} +resume_path = .outputs/last.ckpt +``` + +``` +[exporting] +directory_path = .exports +source_path = .outputs/last.ckpt +target_path = .exports/face_swapper.onnx +target_size = 256 +ir_version = 10 +opset_version = 15 +precision = full +``` + +``` +[inferencing] +generator_path = .outputs/last.ckpt +embedder_path = .models/arcface.pt +source_path = .assets/source.jpg +target_path = .assets/target.jpg +output_path = .outputs/output.jpg +``` + + +Training +-------- + +Train the Face Swapper model. + +``` +python train.py +``` + +Launch the TensorBoard to monitor the training. + +``` +tensorboard --logdir=.logs +``` + + +Exporting +--------- + +Export the model to ONNX. + +``` +python export.py +``` + + +Inferencing +----------- + +Inference the model. + +``` +python infer.py +``` diff --git a/face_swapper/__init__.py b/face_swapper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/face_swapper/config.ini b/face_swapper/config.ini new file mode 100644 index 0000000..fd94c01 --- /dev/null +++ b/face_swapper/config.ini @@ -0,0 +1,75 @@ +[training.dataset] +file_pattern = +warp_template = +transform_size = +batch_mode = +batch_ratio = + +[training.loader] +batch_size = +num_workers = +split_ratio = + +[training.model] +generator_embedder_path = +loss_embedder_path = +gazer_path = +face_masker_path = + +[training.model.generator] +source_channels = +output_channels = +output_size = +num_blocks = + +[training.model.discriminator] +input_channels = +num_filters = +num_layers = +num_discriminators = +kernel_size = + +[training.model.masker] +input_channels = +output_channels = +num_filters = + +[training.losses] +adversarial_weight = +cycle_weight = +feature_weight = +reconstruction_weight = +identity_weight = +gaze_weight = +mask_weight = + +[training.trainer] +accumulate_size = +learning_rate = +max_epochs = +strategy = +precision = +logger_path = +logger_name = +preview_frequency = + +[training.output] +directory_path = +file_pattern = +resume_path = + +[exporting] +directory_path = +source_path = +target_path = +target_size = +ir_version = +opset_version = +precision = + +[inferencing] +generator_path = +embedder_path = +source_path = +target_path = +output_path = diff --git a/arcface_converter/prepare.py b/face_swapper/export.py similarity index 53% rename from arcface_converter/prepare.py rename to face_swapper/export.py index 4cf7306..d560240 100644 --- a/arcface_converter/prepare.py +++ b/face_swapper/export.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from src.preparing import prepare +from src.exporting import export if __name__ == '__main__': - prepare() + export() diff --git a/face_swapper/infer.py b/face_swapper/infer.py new file mode 100644 index 0000000..dde2e9f --- /dev/null +++ b/face_swapper/infer.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from src.inferencing import infer + +if __name__ == '__main__': + infer() diff --git a/face_swapper/src/__init__.py b/face_swapper/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py new file mode 100644 index 0000000..db04932 --- /dev/null +++ b/face_swapper/src/dataset.py @@ -0,0 +1,107 @@ +import glob +import os +import random +from configparser import ConfigParser +from typing import cast + +import albumentations +from torch import Tensor +from torch.utils.data import Dataset +from torchvision import io, transforms + +from .helper import warp_tensor +from .types import Batch, BatchMode, WarpTemplate + + +class DynamicDataset(Dataset[Tensor]): + def __init__(self, config_parser : ConfigParser) -> None: + self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern') + self.config_transform_size = config_parser.getint('training.dataset', 'transform_size') + self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset', 'batch_mode')) + self.config_batch_ratio = config_parser.getfloat('training.dataset', 'batch_ratio') + self.config_parser = config_parser + self.file_paths = glob.glob(self.config_file_pattern) + self.transforms = self.compose_transforms() + + def __getitem__(self, index : int) -> Batch: + file_path = self.file_paths[index] + + if random.random() < self.config_batch_ratio: + if self.config_batch_mode == 'equal': + return self.prepare_equal_batch(file_path) + if self.config_batch_mode == 'same': + return self.prepare_same_batch(file_path) + + return self.prepare_different_batch(file_path) + + def __len__(self) -> int: + return len(self.file_paths) + + def compose_transforms(self) -> transforms: + return transforms.Compose( + [ + AugmentTransform(), + transforms.ToPILImage(), + transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + WarpTransform(self.config_parser), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + def prepare_different_batch(self, source_path : str) -> Batch: + target_path = random.choice(self.file_paths) + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + target_tensor = io.read_image(target_path) + target_tensor = self.transforms(target_tensor) + return source_tensor, target_tensor + + def prepare_equal_batch(self, source_path : str) -> Batch: + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + return source_tensor, source_tensor + + def prepare_same_batch(self, source_path : str) -> Batch: + target_directory_path = os.path.dirname(source_path) + target_file_name_and_extension = random.choice(os.listdir(target_directory_path)) + target_path = os.path.join(target_directory_path, target_file_name_and_extension) + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + target_tensor = io.read_image(target_path) + target_tensor = self.transforms(target_tensor) + return source_tensor, target_tensor + + +class AugmentTransform: + def __init__(self) -> None: + self.transforms = self.compose_transforms() + + def __call__(self, input_tensor : Tensor) -> Tensor: + temp_tensor = input_tensor.numpy().transpose(1, 2, 0) + return self.transforms(image = temp_tensor).get('image') + + @staticmethod + def compose_transforms() -> albumentations.Compose: + return albumentations.Compose( + [ + albumentations.HorizontalFlip(), + albumentations.OneOf( + [ + albumentations.MotionBlur(p = 0.1), + albumentations.ZoomBlur(max_factor = (1.0, 1.1), p = 0.1) + ], p = 0.2), + albumentations.RandomBrightnessContrast(p = 0.7), + albumentations.ColorJitter(p = 0.2), + albumentations.RGBShift(p = 0.7), + albumentations.Illumination(p = 0.2), + albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3) + ]) + + +class WarpTransform: + def __init__(self, config_parser : ConfigParser) -> None: + self.config_warp_template = cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template')) + + def __call__(self, input_tensor : Tensor) -> Tensor: + temp_tensor = input_tensor.unsqueeze(0) + return warp_tensor(temp_tensor, self.config_warp_template).squeeze(0) diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py new file mode 100644 index 0000000..34916a0 --- /dev/null +++ b/face_swapper/src/exporting.py @@ -0,0 +1,47 @@ +import os +from configparser import ConfigParser +from typing import Tuple + +import torch +from torch import Tensor, nn + +from .training import FaceSwapperTrainer +from .types import Embedding, Mask, Module + +CONFIG_PARSER = ConfigParser() +CONFIG_PARSER.read('config.ini') + + +class HalfPrecision(nn.Module): + def __init__(self, model : Module) -> None: + super().__init__() + self.model = model.half() + + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: + source_embedding = source_embedding.half() + target_tensor = target_tensor.half() + output_tensor, output_mask = self.model(source_embedding, target_tensor) + output_tensor = output_tensor.float() + output_mask = output_mask.float() + return output_tensor, output_mask + + +def export() -> None: + config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path') + config_source_path = CONFIG_PARSER.get('exporting', 'source_path') + config_target_path = CONFIG_PARSER.get('exporting', 'target_path') + config_target_size = CONFIG_PARSER.getint('exporting', 'target_size') + config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version') + config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version') + config_precision = CONFIG_PARSER.get('exporting', 'precision') + + os.makedirs(config_directory_path, exist_ok = True) + model = FaceSwapperTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() + + if config_precision == 'half': + model = HalfPrecision(model).eval() + + model.ir_version = torch.tensor(config_ir_version) + source_tensor = torch.randn(1, 512) + target_tensor = torch.randn(1, 3, config_target_size, config_target_size) + torch.onnx.export(model, (source_tensor, target_tensor), config_target_path, input_names = [ 'source', 'target' ], output_names = [ 'output', 'mask' ], opset_version = config_opset_version) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py new file mode 100644 index 0000000..ff9d4d2 --- /dev/null +++ b/face_swapper/src/helper.py @@ -0,0 +1,51 @@ +import torch +from torch import Tensor, nn + +from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet + +WARP_TEMPLATE_SET : WarpTemplateSet =\ +{ + 'arcface_128_v2_to_arcface_112_v2': torch.tensor( + [ + [ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ], + [ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ] + ]), + 'ffhq_to_arcface_128_v2': torch.tensor( + [ + [ 8.50048894e-01, -1.29486822e-04, 1.90956388e-03 ], + [ 1.29486822e-04, 8.50048894e-01, 9.56254653e-02 ] + ]), + 'vgg_face_hq_to_arcface_128_v2': torch.tensor( + [ + [ 1.01305414, -0.00140513, -0.00585911 ], + [ 0.00140513, 1.01305414, 0.11169602 ] + ]) +} + + +def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor: + normed_warp_template = WARP_TEMPLATE_SET.get(warp_template).repeat(input_tensor.shape[0], 1, 1) + affine_grid = nn.functional.affine_grid(normed_warp_template.to(input_tensor.device), list(input_tensor.shape)) + output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, align_corners = False, padding_mode = 'reflection') + return output_tensor + + +def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding: + crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2') + crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area') + crop_tensor[:, :, :padding[0], :] = 0 + crop_tensor[:, :, 112 - padding[1]:, :] = 0 + crop_tensor[:, :, :, :padding[2]] = 0 + crop_tensor[:, :, :, 112 - padding[3]:] = 0 + + embedding = embedder(crop_tensor) + embedding = nn.functional.normalize(embedding, p = 2) + return embedding + + +def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor: + overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device) + overlay_tensor[:, 2, :, :] = 1 + input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8) + output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask + return output_tensor diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py new file mode 100644 index 0000000..6f81b2d --- /dev/null +++ b/face_swapper/src/inferencing.py @@ -0,0 +1,27 @@ +import configparser + +import torch +from torchvision import io + +from .helper import calc_embedding +from .training import FaceSwapperTrainer + +CONFIG_PARSER = configparser.ConfigParser() +CONFIG_PARSER.read('config.ini') + + +def infer() -> None: + config_generator_path = CONFIG_PARSER.get('inferencing', 'generator_path') + config_embedder_path = CONFIG_PARSER.get('inferencing', 'embedder_path') + config_source_path = CONFIG_PARSER.get('inferencing', 'source_path') + config_target_path = CONFIG_PARSER.get('inferencing', 'target_path') + config_output_path = CONFIG_PARSER.get('inferencing', 'output_path') + + generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() + embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval() + + source_tensor = io.read_image(config_source_path) + target_tensor = io.read_image(config_target_path) + source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0)) + output_tensor, _ = generator(source_embedding, target_tensor) + io.write_jpeg(output_tensor, config_output_path) diff --git a/face_swapper/src/models/__init__.py b/face_swapper/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py new file mode 100644 index 0000000..b44305d --- /dev/null +++ b/face_swapper/src/models/discriminator.py @@ -0,0 +1,35 @@ +from configparser import ConfigParser +from typing import List + +from torch import Tensor, nn + +from ..networks.nld import NLD + + +class Discriminator(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators') + self.config_parser = config_parser + self.discriminators = self.create_discriminators() + self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) + + def create_discriminators(self) -> nn.ModuleList: + discriminators = nn.ModuleList() + + for _ in range(self.config_num_discriminators): + discriminator = NLD(self.config_parser).sequences + discriminators.append(discriminator) + + return discriminators + + def forward(self, input_tensor : Tensor) -> List[Tensor]: + temp_tensor = input_tensor + output_tensors = [] + + for discriminator in self.discriminators: + output_tensor = discriminator(temp_tensor) + output_tensors.append(output_tensor) + temp_tensor = self.avg_pool(temp_tensor) + + return output_tensors diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py new file mode 100644 index 0000000..ab76154 --- /dev/null +++ b/face_swapper/src/models/generator.py @@ -0,0 +1,42 @@ +from configparser import ConfigParser +from typing import Tuple + +from torch import Tensor, nn + +from ..networks.aad import AAD +from ..networks.masknet import MaskNet +from ..networks.unet import UNet +from ..types import Embedding, Feature, Mask + + +class Generator(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.encoder = UNet(config_parser) + self.generator = AAD(config_parser) + self.masker = MaskNet(config_parser) + self.encoder.apply(init_weight) + self.generator.apply(init_weight) + self.masker.apply(init_weight) + + def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]: + output_tensor = self.generator(source_embedding, target_features) + target_feature = target_features[-1] + output_mask = self.masker(target_tensor, target_feature) + output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask) + return output_tensor, output_mask + + def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]: + return self.encoder(input_tensor) + + +def init_weight(module : nn.Module) -> None: + if isinstance(module, nn.Linear): + module.weight.data.normal_(std = 0.001) + module.bias.data.zero_() + + if isinstance(module, nn.Conv2d): + nn.init.xavier_normal_(module.weight.data) + + if isinstance(module, nn.ConvTranspose2d): + nn.init.xavier_normal_(module.weight.data) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py new file mode 100644 index 0000000..dbe8e10 --- /dev/null +++ b/face_swapper/src/models/loss.py @@ -0,0 +1,186 @@ +from configparser import ConfigParser +from typing import List, Tuple + +import torch +from pytorch_msssim import ssim +from torch import Tensor, nn +from torchvision import transforms + +from ..helper import calc_embedding +from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask + + +class DiscriminatorLoss(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss: + positive_tensors = [] + negative_tensors = [] + + for discriminator_source_tensor in discriminator_source_tensors: + positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ]) + positive_tensors.append(positive_tensor) + + for discriminator_output_tensor in discriminator_output_tensors: + negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ]) + negative_tensors.append(negative_tensor) + + positive_loss = torch.stack(positive_tensors).mean() + negative_loss = torch.stack(negative_tensors).mean() + discriminator_loss = (positive_loss + negative_loss) * 0.5 + return discriminator_loss + + +class AdversarialLoss(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight') + + def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]: + temp_tensors = [] + + for discriminator_output_tensor in discriminator_output_tensors: + temp_tensor = torch.relu(1 - discriminator_output_tensor).mean(dim = [ 1, 2, 3 ]).mean() + temp_tensors.append(temp_tensor) + + adversarial_loss = torch.stack(temp_tensors).mean() + weighted_adversarial_loss = adversarial_loss * self.config_adversarial_weight + return adversarial_loss, weighted_adversarial_loss + + +class CycleLoss(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_batch_size = config_parser.getint('training.loader', 'batch_size') + self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight') + self.l1_loss = nn.L1Loss() + + def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]: + temp_tensors = [] + + for target_feature, output_feature in zip(target_features, cycle_features): + temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean() + temp_tensors.append(temp_tensor) + + feature_loss = torch.stack(temp_tensors).mean() + reconstruction_loss = self.l1_loss(target_tensor, cycle_tensor) + cycle_loss = (feature_loss + reconstruction_loss) * 0.5 + weighted_feature_loss = cycle_loss * self.config_cycle_weight + return cycle_loss, weighted_feature_loss + + +class FeatureLoss(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_batch_size = config_parser.getint('training.loader', 'batch_size') + self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight') + + def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]: + temp_tensors = [] + + for target_feature, output_feature in zip(target_features, output_features): + temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean() + temp_tensors.append(temp_tensor) + + feature_loss = torch.stack(temp_tensors).mean() * 0.5 + weighted_feature_loss = feature_loss * self.config_feature_weight + return feature_loss, weighted_feature_loss + + +class ReconstructionLoss(nn.Module): + def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None: + super().__init__() + self.config_reconstruction_weight = config_parser.getfloat('training.losses', 'reconstruction_weight') + self.embedder = embedder + self.mse_loss = nn.MSELoss() + + def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: + with torch.no_grad(): + source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) + + has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 + + reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) + reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5 + + data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) + visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() + reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5 + weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight + return reconstruction_loss, weighted_reconstruction_loss + + +class IdentityLoss(nn.Module): + def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None: + super().__init__() + self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight') + self.embedder = embedder + + def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: + output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) + source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) + identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() + weighted_identity_loss = identity_loss * self.config_identity_weight + return identity_loss, weighted_identity_loss + + +class GazeLoss(nn.Module): + def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None: + super().__init__() + self.config_gaze_weight = config_parser.getfloat('training.losses', 'gaze_weight') + self.config_output_size = config_parser.getint('training.model.generator', 'output_size') + self.gazer = gazer + self.l1_loss = nn.L1Loss() + + def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: + output_pitch, output_yaw = self.detect_gaze(output_tensor) + target_pitch, target_yaw = self.detect_gaze(target_tensor) + + pitch_loss = self.l1_loss(output_pitch, target_pitch) + yaw_loss = self.l1_loss(output_yaw, target_yaw) + + gaze_loss = (pitch_loss + yaw_loss) * 0.5 + weighted_gaze_loss = gaze_loss * self.config_gaze_weight + return gaze_loss, weighted_gaze_loss + + def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]: + crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int() + crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]] + crop_tensor = (crop_tensor + 1) * 0.5 + crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor) + crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic') + + with torch.no_grad(): + pitch, yaw = self.gazer(crop_tensor) + + return pitch, yaw + + +class MaskLoss(nn.Module): + def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None: + super().__init__() + self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight') + self.config_output_size = config_parser.getint('training.model.generator', 'output_size') + self.face_masker = face_masker + self.mse_loss = nn.MSELoss() + + def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]: + target_mask = self.calc_mask(target_tensor) + target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) + output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size) + mask_loss = self.mse_loss(target_mask, output_mask) + weighted_mask_loss = mask_loss * self.config_mask_weight + return mask_loss, weighted_mask_loss + + def calc_mask(self, target_tensor : Tensor) -> Tensor: + target_tensor = torch.nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear') + target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5 + + with torch.no_grad(): + output_tensor = self.face_masker(target_tensor) + output_tensor = output_tensor.clamp(0, 1) + output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear') + + return output_tensor diff --git a/face_swapper/src/networks/__init__.py b/face_swapper/src/networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py new file mode 100644 index 0000000..10c1231 --- /dev/null +++ b/face_swapper/src/networks/aad.py @@ -0,0 +1,191 @@ +from configparser import ConfigParser +from typing import Tuple + +import torch +from torch import Tensor, nn + +from ..types import Embedding, Feature + + +class AAD(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels') + self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels') + self.config_output_size = config_parser.getint('training.model.generator', 'output_size') + self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks') + self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels) + self.layers = self.create_layers() + + def create_layers(self) -> nn.ModuleList: + layers = nn.ModuleList() + + if self.config_output_size == 128: + layers.extend( + [ + AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks) + ]) + + if self.config_output_size == 256: + layers.extend( + [ + AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) + ]) + + if self.config_output_size == 512: + layers.extend( + [ + AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) + ]) + + if self.config_output_size == 1024: + layers.extend( + [ + AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) + ]) + + layers.extend( + [ + AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(256, 128, 128, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(128, 64, 64, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(64, 3, 64, self.config_source_channels, self.config_num_blocks) + ]) + + return layers + + def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor: + temp_tensors = self.pixel_shuffle_up_sample(source_embedding) + + for index, layer in enumerate(self.layers[:-1]): + target_feature = target_features[index] + temp_tensor = layer(temp_tensors, source_embedding, target_feature) + temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False) + + target_feature = target_features[-1] + temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_feature) + output_tensor = torch.tanh(temp_tensors) + return output_tensor + + +class AdaptiveFeatureModulation(nn.Module): + def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None: + super().__init__() + self.context_input_channels = input_channels + self.context_output_channels = output_channels + self.context_target_channels = target_channels + self.context_source_channels = source_channels + self.context_num_blocks = num_blocks + self.primary_layers = self.create_primary_layers() + self.shortcut_layers = self.create_shortcut_layers() + + def create_primary_layers(self) -> nn.ModuleList: + primary_layers = nn.ModuleList() + + for index in range(self.context_num_blocks): + primary_layers.extend( + [ + FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels), + nn.ReLU(inplace = True) + ]) + + if index < self.context_num_blocks - 1: + primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_input_channels, kernel_size = 3, padding = 1, bias = False)) + else: + primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)) + + return primary_layers + + def create_shortcut_layers(self) -> nn.ModuleList: + shortcut_layers = nn.ModuleList() + + if self.context_input_channels > self.context_output_channels: + shortcut_layers.extend( + [ + FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels), + nn.ReLU(inplace = True), + nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False) + ]) + + return shortcut_layers + + def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor: + primary_tensor = input_tensor + + for primary_layer in self.primary_layers: + if isinstance(primary_layer, FeatureModulation): + primary_tensor = primary_layer(primary_tensor, source_embedding, target_feature) + else: + primary_tensor = primary_layer(primary_tensor) + + if self.context_input_channels > self.context_output_channels: + shortcut_tensor = input_tensor + + for shortcut_layer in self.shortcut_layers: + if isinstance(shortcut_layer, FeatureModulation): + shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_feature) + else: + shortcut_tensor = shortcut_layer(shortcut_tensor) + + input_tensor = shortcut_tensor + + return primary_tensor + input_tensor + + +class FeatureModulation(nn.Module): + def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None: + super().__init__() + self.context_input_channels = input_channels + self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1) + self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1) + self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1) + self.linear1 = nn.Linear(source_channels, input_channels) + self.linear2 = nn.Linear(source_channels, input_channels) + self.instance_norm = nn.InstanceNorm2d(input_channels) + + def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor: + temp_tensor = self.instance_norm(input_tensor) + + source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor) + source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor) + source_modulation = source_scale * temp_tensor + source_shift + + target_scale = self.conv1(target_feature) + target_shift = self.conv2(target_feature) + target_modulation = target_scale * temp_tensor + target_shift + + temp_mask = torch.sigmoid(self.conv3(temp_tensor)) + output_tensor = (1 - temp_mask) * target_modulation + temp_mask * source_modulation + return output_tensor + + +class PixelShuffleUpSample(nn.Module): + def __init__(self, input_channels : int, output_channels : int) -> None: + super().__init__() + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1), + nn.PixelShuffle(upscale_factor = 2) + ) + + def forward(self, input_tensor : Tensor) -> Tensor: + temp_tensor = input_tensor.view(input_tensor.shape[0], -1, 1, 1) + output_tensor = self.sequences(temp_tensor) + return output_tensor diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py new file mode 100644 index 0000000..023b767 --- /dev/null +++ b/face_swapper/src/networks/masknet.py @@ -0,0 +1,111 @@ +from configparser import ConfigParser + +import torch +from torch import Tensor, nn + +from ..types import Feature, Mask + + +class MaskNet(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels') + self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels') + self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters') + self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters) + self.up_samples = self.create_up_samples(self.config_num_filters) + self.bottleneck = BottleNeck(self.config_num_filters * 4) + self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1) + self.sigmoid = nn.Sigmoid() + + @staticmethod + def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList: + return nn.ModuleList( + [ + DownSample(input_channels, num_filters), + DownSample(num_filters, num_filters * 2), + DownSample(num_filters * 2, num_filters * 4) + ]) + + @staticmethod + def create_up_samples(num_filters : int) -> nn.ModuleList: + return nn.ModuleList( + [ + UpSample(num_filters * 4, num_filters * 2), + UpSample(num_filters * 2, num_filters), + UpSample(num_filters, num_filters) + ]) + + def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask: + output_mask = torch.cat([ input_tensor, input_feature ], dim = 1) + + for down_sample in self.down_samples: + output_mask = down_sample(output_mask) + + output_mask = self.bottleneck(output_mask) + + for up_sample in self.up_samples: + output_mask = up_sample(output_mask) + + output_mask = self.conv(output_mask) + output_mask = self.sigmoid(output_mask) + return output_mask + + +class BottleNeck(nn.Module): + def __init__(self, num_filters : int): + super().__init__() + self.sequences = self.create_sequences(num_filters) + self.relu = nn.ReLU(inplace = True) + + @staticmethod + def create_sequences(num_filters : int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False), + nn.BatchNorm2d(num_filters), + nn.ReLU(inplace = True), + nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False), + nn.BatchNorm2d(num_filters), + nn.ReLU(inplace = True) + ) + + def forward(self, input_tensor : Tensor) -> Tensor: + output_tensor = self.sequences(input_tensor) + input_tensor + output_tensor = self.relu(output_tensor) + return output_tensor + + +class UpSample(nn.Module): + def __init__(self, input_channels : int, output_channels : int) -> None: + super().__init__() + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2), + nn.ReLU(inplace = True) + ) + + def forward(self, input_tensor : Tensor) -> Tensor: + output_tensor = self.sequences(input_tensor) + return output_tensor + + +class DownSample(nn.Module): + def __init__(self, input_channels : int, output_channels : int) -> None: + super().__init__() + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace = True), + nn.MaxPool2d(2) + ) + + def forward(self, input_tensor : Tensor) -> Tensor: + output_tensor = self.sequences(input_tensor) + return output_tensor diff --git a/face_swapper/src/networks/nld.py b/face_swapper/src/networks/nld.py new file mode 100644 index 0000000..cf2ffd9 --- /dev/null +++ b/face_swapper/src/networks/nld.py @@ -0,0 +1,48 @@ +import math +from configparser import ConfigParser + +from torch import Tensor, nn + + +class NLD(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_input_channels = config_parser.getint('training.model.discriminator', 'input_channels') + self.config_num_filters = config_parser.getint('training.model.discriminator', 'num_filters') + self.config_kernel_size = config_parser.getint('training.model.discriminator', 'kernel_size') + self.config_num_layers = config_parser.getint('training.model.discriminator', 'num_layers') + self.layers = self.create_layers() + self.sequences = nn.Sequential(*self.layers) + + def create_layers(self) -> nn.ModuleList: + padding = math.ceil((self.config_kernel_size - 1) / 2) + current_filters = self.config_num_filters + layers = nn.ModuleList( + [ + nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding), + nn.LeakyReLU(0.2, True) + ]) + + for _ in range(1, self.config_num_layers): + previous_filters = current_filters + current_filters = min(current_filters * 2, 512) + layers +=\ + [ + nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding), + nn.InstanceNorm2d(current_filters), + nn.LeakyReLU(0.2, True) + ] + + previous_filters = current_filters + current_filters = min(current_filters * 2, 512) + layers +=\ + [ + nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding), + nn.InstanceNorm2d(current_filters), + nn.LeakyReLU(0.2, True), + nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding) + ] + return layers + + def forward(self, input_tensor : Tensor) -> Tensor: + return self.sequences(input_tensor) diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py new file mode 100644 index 0000000..5d63561 --- /dev/null +++ b/face_swapper/src/networks/unet.py @@ -0,0 +1,157 @@ +from configparser import ConfigParser +from typing import Tuple + +import torch +from torch import Tensor, nn + +from ..types import Feature + + +class UNet(nn.Module): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_output_size = config_parser.getint('training.model.generator', 'output_size') + self.down_samples = self.create_down_samples() + self.up_samples = self.create_up_samples() + + def create_down_samples(self) -> nn.ModuleList: + down_samples = nn.ModuleList( + [ + DownSample(3, 32), + DownSample(32, 64), + DownSample(64, 128), + DownSample(128, 256), + DownSample(256, 512) + ]) + + if self.config_output_size == 128: + down_samples.extend( + [ + DownSample(512, 512) + ]) + + if self.config_output_size == 256: + down_samples.extend( + [ + DownSample(512, 1024), + DownSample(1024, 1024) + ]) + + if self.config_output_size == 512: + down_samples.extend( + [ + DownSample(512, 1024), + DownSample(1024, 2048), + DownSample(2048, 2048) + ]) + + if self.config_output_size == 1024: + down_samples.extend( + [ + DownSample(512, 1024), + DownSample(1024, 2048), + DownSample(2048, 4096), + DownSample(4096, 4096) + ]) + + return down_samples + + def create_up_samples(self) -> nn.ModuleList: + up_samples = nn.ModuleList() + + if self.config_output_size == 128: + up_samples.extend( + [ + UpSample(512, 512) + ]) + + if self.config_output_size == 256: + up_samples.extend( + [ + UpSample(1024, 1024), + UpSample(2048, 512) + ]) + + if self.config_output_size == 512: + up_samples.extend( + [ + UpSample(2048, 2048), + UpSample(4096, 1024), + UpSample(2048, 512) + ]) + + if self.config_output_size == 1024: + up_samples.extend( + [ + UpSample(4096, 4096), + UpSample(8192, 2048), + UpSample(4096, 1024), + UpSample(2048, 512) + ]) + + up_samples.extend( + [ + UpSample(1024, 256), + UpSample(512, 128), + UpSample(256, 64), + UpSample(128, 32) + ]) + + return up_samples + + def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]: + down_features = [] + up_features = [] + temp_feature = target_tensor + + for down_sample in self.down_samples: + temp_feature = down_sample(temp_feature) + down_features.append(temp_feature) + + bottleneck_feature = down_features[-1] + temp_feature = bottleneck_feature + + for index, up_sample in enumerate(self.up_samples): + skip_tensor = down_features[-(index + 2)] + temp_feature = up_sample(temp_feature, skip_tensor) + up_features.append(temp_feature) + + final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False) + return bottleneck_feature, *up_features, final_feature + + +class UpSample(nn.Module): + def __init__(self, input_channels : int, output_channels : int) -> None: + super().__init__() + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False), + nn.BatchNorm2d(output_channels), + nn.LeakyReLU(0.1, inplace = True) + ) + + def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor: + output_tensor = self.sequences(input_tensor) + output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1) + return output_tensor + + +class DownSample(nn.Module): + def __init__(self, input_channels : int, output_channels : int) -> None: + super().__init__() + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False), + nn.BatchNorm2d(output_channels), + nn.LeakyReLU(0.1, inplace = True) + ) + + def forward(self, input_tensor : Tensor) -> Tensor: + output_tensor = self.sequences(input_tensor) + return output_tensor diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py new file mode 100644 index 0000000..67662b4 --- /dev/null +++ b/face_swapper/src/training.py @@ -0,0 +1,248 @@ +import os +import warnings +from configparser import ConfigParser +from typing import List, Tuple + +import torch +import torchvision +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger +from torch import Tensor, nn +from torch.utils.data import ConcatDataset, Dataset, random_split +from torchdata.stateful_dataloader import StatefulDataLoader + +from .dataset import DynamicDataset +from .helper import calc_embedding, overlay_mask +from .models.discriminator import Discriminator +from .models.generator import Generator +from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss +from .types import Batch, Embedding, Mask, OptimizerSet + +warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') + +CONFIG_PARSER = ConfigParser() +CONFIG_PARSER.read('config.ini') + + +class FaceSwapperTrainer(LightningModule): + def __init__(self, config_parser : ConfigParser) -> None: + super().__init__() + self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path') + self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path') + self.config_gazer_path = config_parser.get('training.model', 'gazer_path') + self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') + self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') + self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') + self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() + self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval() + self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval() + self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval() + self.generator = Generator(config_parser) + self.discriminator = Discriminator(config_parser) + self.discriminator_loss = DiscriminatorLoss() + self.adversarial_loss = AdversarialLoss(config_parser) + self.cycle_loss = CycleLoss(config_parser) + self.feature_loss = FeatureLoss(config_parser) + self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder) + self.identity_loss = IdentityLoss(config_parser, self.loss_embedder) + self.gaze_loss = GazeLoss(config_parser, self.gazer) + self.mask_loss = MaskLoss(config_parser, self.face_masker) + self.automatic_optimization = False + + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: + with torch.no_grad(): + generator_target_features = self.generator.encode_features(target_tensor) + output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features) + + return output_tensor, output_mask + + def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) + discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) + + generator_config =\ + { + 'optimizer': generator_optimizer, + 'lr_scheduler': + { + 'scheduler': generator_scheduler, + 'interval': 'step' + } + } + discriminator_config =\ + { + 'optimizer': discriminator_optimizer, + 'lr_scheduler': + { + 'scheduler': discriminator_scheduler, + 'interval': 'step' + } + } + return generator_config, discriminator_config + + def training_step(self, batch : Batch, batch_index : int) -> Tensor: + source_tensor, target_tensor = batch + do_update = (batch_index + 1) % self.config_accumulate_size == 0 + generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] + + source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0)) + generator_target_features = self.generator.encode_features(target_tensor) + generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features) + generator_output_features = self.generator.encode_features(generator_output_tensor) + cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features) + cycle_output_features = self.generator.encode_features(cycle_output_tensor) + discriminator_output_tensors = self.discriminator(generator_output_tensor) + adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) + cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features) + feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features) + reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) + identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) + gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor) + mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask) + generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss + + discriminator_source_tensors = self.discriminator(source_tensor) + discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) + discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) + + self.toggle_optimizer(generator_optimizer) + self.manual_backward(generator_loss) + + if do_update: + generator_optimizer.step() + generator_optimizer.zero_grad() + self.untoggle_optimizer(generator_optimizer) + + self.toggle_optimizer(discriminator_optimizer) + self.manual_backward(discriminator_loss) + + if do_update: + discriminator_optimizer.step() + discriminator_optimizer.zero_grad() + self.untoggle_optimizer(discriminator_optimizer) + + if self.global_step % self.config_preview_frequency == 0: + self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask) + + self.log('generator_loss', generator_loss, prog_bar = True) + self.log('discriminator_loss', discriminator_loss, prog_bar = True) + self.log('adversarial_loss', adversarial_loss) + self.log('cycle_loss', cycle_loss) + self.log('feature_loss', feature_loss) + self.log('reconstruction_loss', reconstruction_loss) + self.log('identity_loss', identity_loss) + self.log('gaze_loss', gaze_loss) + self.log('mask_loss', mask_loss) + return generator_loss + + def validation_step(self, batch : Batch, batch_index : int) -> Tensor: + source_tensor, target_tensor = batch + source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) + output_tensor, _ = self.forward(source_embedding, target_tensor) + output_embedding = calc_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0)) + validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 + self.log('validation_score', validation_score, sync_dist = True, prog_bar = True) + return validation_score + + def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None: + preview_limit = 8 + preview_cells = [] + overlay_tensor = overlay_mask(output_tensor, output_mask) + + for source_tensor, target_tensor, output_tensor, overlay_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]): + preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2) + preview_cells.append(preview_cell) + + preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0) + preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True) + self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined] + + +def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: + config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size') + config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers') + + training_dataset, validate_dataset = split_dataset(dataset) + training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True) + return training_loader, validation_loader + + +def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]: + config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio') + + dataset_size = len(dataset) # type:ignore[arg-type] + training_size = int(dataset_size * config_split_ratio) + validation_size = int(dataset_size - training_size) + training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ]) + return training_dataset, validate_dataset + + +def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]: + datasets = [] + + for config_section in config_parser.sections(): + + if config_section.startswith('training.dataset'): + current_config_parser = ConfigParser() + current_config_parser.add_section('training.dataset') + + for key, value in config_parser.items(config_section): + current_config_parser.set('training.dataset', key, value) + + datasets.append(DynamicDataset(current_config_parser)) + + return datasets + + +def create_trainer() -> Trainer: + config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') + config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') + config_precision = CONFIG_PARSER.get('training.trainer', 'precision') + config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') + config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') + logger = TensorBoardLogger(config_logger_path, config_logger_name) + + return Trainer( + logger = logger, + log_every_n_steps = 10, + max_epochs = config_max_epochs, + strategy = config_strategy, + precision = config_precision, + callbacks = + [ + ModelCheckpoint( + monitor = 'generator_loss', + dirpath = config_directory_path, + filename = config_file_pattern, + every_n_train_steps = 1000, + save_top_k = 3, + save_last = True + ) + ], + val_check_interval = 1000 + ) + + +def train() -> None: + config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path') + + if torch.cuda.is_available(): + torch.set_float32_matmul_precision('high') + + dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER)) + training_loader, validation_loader = create_loaders(dataset) + face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER) + trainer = create_trainer() + + if os.path.isfile(config_resume_path): + trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config_resume_path) + else: + trainer.fit(face_swapper_trainer, training_loader, validation_loader) diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py new file mode 100644 index 0000000..d840897 --- /dev/null +++ b/face_swapper/src/types.py @@ -0,0 +1,24 @@ +from typing import Any, Dict, Literal, Tuple, TypeAlias + +from torch import Tensor +from torch.nn import Module + +Batch : TypeAlias = Tuple[Tensor, Tensor] +BatchMode = Literal['equal', 'same', 'different'] + +Feature : TypeAlias = Tensor +Embedding : TypeAlias = Tensor +Mask : TypeAlias = Tensor +Loss : TypeAlias = Tensor + +Padding : TypeAlias = Tuple[int, int, int, int] + +GeneratorModule : TypeAlias = Module +EmbedderModule : TypeAlias = Module +GazerModule : TypeAlias = Module +FaceMaskerModule : TypeAlias = Module + +OptimizerSet : TypeAlias = Any + +WarpTemplate = Literal['arcface_128_v2_to_arcface_112_v2', 'ffhq_to_arcface_128_v2', 'vgg_face_hq_to_arcface_128_v2'] +WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor] diff --git a/face_swapper/tests/test_networks.py b/face_swapper/tests/test_networks.py new file mode 100644 index 0000000..654a172 --- /dev/null +++ b/face_swapper/tests/test_networks.py @@ -0,0 +1,57 @@ +from configparser import ConfigParser + +import pytest +import torch + +from face_swapper.src.networks.aad import AAD +from face_swapper.src.networks.masknet import MaskNet +from face_swapper.src.networks.unet import UNet + + +@pytest.mark.parametrize('output_size', [ 128, 256, 512 ]) +def test_aad_with_unet(output_size : int) -> None: + config_parser = ConfigParser() + config_parser.read_dict( + { + 'training.model.generator': + { + 'source_channels': '512', + 'output_channels': str(output_size * 16), + 'output_size': str(output_size), + 'num_blocks': '2' + } + }) + + encoder = UNet(config_parser).eval() + generator = AAD(config_parser).eval() + + source_tensor = torch.randn(1, 512) + target_tensor = torch.randn(1, 3, output_size, output_size) + + target_features = encoder(target_tensor) + output_tensor = generator(source_tensor, target_features) + + assert output_tensor.shape == (1, 3, output_size, output_size) + + +@pytest.mark.parametrize('output_size', [ 128, 256, 512 ]) +def test_mask_net(output_size : int) -> None: + config_parser = ConfigParser() + config_parser.read_dict( + { + 'training.model.masker': + { + 'input_channels': '67', + 'output_channels': '1', + 'num_filters': '16' + } + }) + + masker = MaskNet(config_parser).eval() + + target_tensor = torch.randn(1, 3, output_size, output_size) + target_feature = torch.randn(1, 64, output_size, output_size) + + output_mask = masker(target_tensor, target_feature) + + assert output_mask.shape == (1, 1, output_size, output_size) diff --git a/face_swapper/train.py b/face_swapper/train.py new file mode 100644 index 0000000..3591cc8 --- /dev/null +++ b/face_swapper/train.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from src.training import train + +if __name__ == '__main__': + train() diff --git a/mypy.ini b/mypy.ini index 64218bc..182b87b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,3 +5,4 @@ disallow_untyped_calls = True disallow_untyped_defs = True ignore_missing_imports = True strict_optional = False +explicit_package_bases = True diff --git a/requirements.txt b/requirements.txt index ec67d61..c354c4c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,10 @@ -lightning==2.4.0 -numpy==1.26.4 +--extra-index-url https://download.pytorch.org/whl/cu124 +albumentations==2.0.5 +lightning==2.5.1 onnx==1.17.0 -onnxruntime==1.20.0 -opencv-python==4.10.0.84 -mxnet==1.9.1 +onnxruntime==1.21.0 +pytorch-msssim==1.0.0 +torch==2.6.0 +torchdata==0.11.0 +torchvision==0.21.0 +tensorboard==2.19.0