mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
@@ -1,7 +1,5 @@
|
||||
[flake8]
|
||||
select = E3, E4, F, I1, I2
|
||||
select = E22, E23, E24, E27, E3, E4, E7, F, I1, I2
|
||||
plugins = flake8-import-order
|
||||
application_import_names = arcface_converter
|
||||
application_import_names = embedding_converter, face_swapper
|
||||
import-order-style = pycharm
|
||||
per-file-ignores = preparing.py:E402
|
||||
|
||||
|
||||
|
Before Width: | Height: | Size: 1.3 MiB After Width: | Height: | Size: 1.3 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 5.2 MiB |
@@ -8,12 +8,24 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python 3.10
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
- run: pip install flake8
|
||||
- run: pip install flake8-import-order
|
||||
- run: pip install mypy
|
||||
- run: flake8 arcface_converter
|
||||
- run: mypy arcface_converter
|
||||
- run: flake8 embedding_converter face_swapper
|
||||
- run: mypy embedding_converter face_swapper
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- run: pip install torch torchvision
|
||||
- run: pip install pytest
|
||||
- run: PYTHONPATH=/home/runner/work/facefusion-labs/facefusion-labs pytest
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
__pycache__
|
||||
.assets
|
||||
.datasets
|
||||
.idea
|
||||
.inputs
|
||||
.exports
|
||||
.logs
|
||||
.models
|
||||
.outputs
|
||||
.vscode
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
MIT license
|
||||
|
||||
Copyright (c) 2024 Henry Ruhs
|
||||
@@ -1,93 +0,0 @@
|
||||
ArcFace Converter
|
||||
=================
|
||||
|
||||
> Convert face embeddings between various ArcFace models.
|
||||
|
||||

|
||||
|
||||
|
||||
Preview
|
||||
-------
|
||||
|
||||

|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap.
|
||||
|
||||
```
|
||||
[preparing.dataset]
|
||||
dataset_path = datasets/megaface/train.rec
|
||||
crop_size = 112
|
||||
process_limit = 650000
|
||||
|
||||
[preparing.model]
|
||||
source_path = models/arcface_w600k_r50.onnx
|
||||
target_path = models/arcface_simswap.onnx
|
||||
|
||||
[preparing.input]
|
||||
directory_path = inputs
|
||||
source_path = inputs/arcface_w600k_r50.npy
|
||||
target_path = inputs/arcface_simswap.npy
|
||||
|
||||
[training.loader]
|
||||
split_ratio = 0.8
|
||||
batch_size = 51200
|
||||
num_workers = 8
|
||||
|
||||
[training.trainer]
|
||||
max_epochs = 4096
|
||||
|
||||
[training.output]
|
||||
directory_path = outputs
|
||||
file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f}
|
||||
|
||||
[exporting]
|
||||
directory_path = exports
|
||||
source_path = outputs/last.ckpt
|
||||
target_path = exports/arcface_converter_simswap.onnx
|
||||
opset_version = 15
|
||||
|
||||
[execution]
|
||||
providers = CUDAExecutionProvider
|
||||
```
|
||||
|
||||
|
||||
Preparing
|
||||
---------
|
||||
|
||||
Prepare the face embedding pairs.
|
||||
|
||||
```
|
||||
python prepare.py
|
||||
```
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
||||
Train the ArcFace converter model.
|
||||
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
|
||||
|
||||
Exporting
|
||||
---------
|
||||
|
||||
Export the model to ONNX.
|
||||
|
||||
```
|
||||
python export.py
|
||||
```
|
||||
@@ -1,22 +0,0 @@
|
||||
import configparser
|
||||
from os import makedirs
|
||||
|
||||
import torch
|
||||
|
||||
from .training import ArcFaceConverterTrainer
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def export() -> None:
|
||||
directory_path = CONFIG.get('exporting', 'directory_path')
|
||||
source_path = CONFIG.get('exporting', 'source_path')
|
||||
target_path = CONFIG.get('exporting', 'target_path')
|
||||
opset_version = CONFIG.getint('exporting', 'opset_version')
|
||||
|
||||
makedirs(directory_path, exist_ok = True)
|
||||
model = ArcFaceConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
|
||||
model.eval()
|
||||
input_tensor = torch.randn(1, 512)
|
||||
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
@@ -1,21 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ArcFaceConverter(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(ArcFaceConverter, self).__init__()
|
||||
self.fc1 = nn.Linear(512, 1024)
|
||||
self.fc2 = nn.Linear(1024, 2048)
|
||||
self.fc3 = nn.Linear(2048, 1024)
|
||||
self.fc4 = nn.Linear(1024, 512)
|
||||
self.activation = nn.LeakyReLU()
|
||||
|
||||
def forward(self, inputs : Tensor) -> Tensor:
|
||||
norm_inputs = inputs / torch.norm(inputs)
|
||||
outputs = self.activation(self.fc1(norm_inputs))
|
||||
outputs = self.activation(self.fc2(outputs))
|
||||
outputs = self.activation(self.fc3(outputs))
|
||||
outputs = self.fc4(outputs)
|
||||
return outputs
|
||||
@@ -1,79 +0,0 @@
|
||||
import configparser
|
||||
from os import makedirs
|
||||
from os.path import isfile
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
numpy.bool = numpy.bool_
|
||||
from mxnet.io import ImageRecordIter
|
||||
from onnxruntime import InferenceSession
|
||||
from tqdm import tqdm
|
||||
|
||||
from .typing import Embedding, EmbeddingPairs, VisionFrame
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
|
||||
crop_vision_frame = crop_vision_frame.astype(numpy.float32) / 255
|
||||
crop_vision_frame = (crop_vision_frame - 0.5) * 2
|
||||
return crop_vision_frame
|
||||
|
||||
|
||||
def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession:
|
||||
inference_session = InferenceSession(model_path, providers = execution_providers)
|
||||
return inference_session
|
||||
|
||||
|
||||
def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding:
|
||||
embedding = inference_session.run(None,
|
||||
{
|
||||
'input': crop_vision_frame
|
||||
})[0]
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
def process_embeddings(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingPairs:
|
||||
dataset_process_limit = CONFIG.getint('preparing.dataset', 'process_limit')
|
||||
embedding_pairs = []
|
||||
|
||||
with tqdm(total = dataset_process_limit) as progress:
|
||||
for batch in dataset_reader:
|
||||
crop_vision_frame = batch.data[0].asnumpy()
|
||||
crop_vision_frame = prepare_crop_vision_frame(crop_vision_frame)
|
||||
source_embedding = forward(source_inference_session, crop_vision_frame)
|
||||
target_embedding = forward(target_inference_session, crop_vision_frame)
|
||||
embedding_pairs.append([ source_embedding, target_embedding ])
|
||||
progress.update()
|
||||
|
||||
if progress.n == dataset_process_limit:
|
||||
return numpy.concatenate(embedding_pairs, axis = 1).T
|
||||
|
||||
return numpy.concatenate(embedding_pairs, axis = 1).T
|
||||
|
||||
|
||||
def prepare() -> None:
|
||||
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
|
||||
dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size')
|
||||
model_source_path = CONFIG.get('preparing.model', 'source_path')
|
||||
model_target_path = CONFIG.get('preparing.model', 'target_path')
|
||||
input_directory_path = CONFIG.get('preparing.input', 'directory_path')
|
||||
input_source_path = CONFIG.get('preparing.input', 'source_path')
|
||||
input_target_path = CONFIG.get('preparing.input', 'target_path')
|
||||
execution_providers = CONFIG.get('execution', 'providers').split(' ')
|
||||
|
||||
makedirs(input_directory_path, exist_ok = True)
|
||||
if isfile(dataset_path) and isfile(model_source_path) and isfile(model_target_path):
|
||||
dataset_reader = ImageRecordIter(
|
||||
path_imgrec = dataset_path,
|
||||
data_shape = (3, dataset_crop_size, dataset_crop_size),
|
||||
batch_size = 1,
|
||||
shuffle = False
|
||||
)
|
||||
source_inference_session = create_inference_session(model_source_path, execution_providers)
|
||||
target_inference_session = create_inference_session(model_target_path, execution_providers)
|
||||
embedding_pairs = process_embeddings(dataset_reader, source_inference_session, target_inference_session)
|
||||
numpy.save(input_source_path, embedding_pairs[..., 0].T)
|
||||
numpy.save(input_target_path, embedding_pairs[..., 1].T)
|
||||
@@ -1,116 +0,0 @@
|
||||
import configparser
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy
|
||||
import pytorch_lightning
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.tuner.tuning import Tuner
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
|
||||
|
||||
from .model import ArcFaceConverter
|
||||
from .typing import Batch, Loader
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class ArcFaceConverterTrainer(pytorch_lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super(ArcFaceConverterTrainer, self).__init__()
|
||||
self.model = ArcFaceConverter()
|
||||
self.loss_fn = torch.nn.MSELoss()
|
||||
self.lr = 0.001
|
||||
|
||||
def forward(self, source_embedding : Tensor) -> Tensor:
|
||||
return self.model(source_embedding)
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_embedding, target_embedding = batch
|
||||
output_embedding = self(source_embedding)
|
||||
loss = self.loss_fn(output_embedding, target_embedding)
|
||||
self.log('train_loss', loss, prog_bar = True, logger = True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_embedding, target_embedding = batch
|
||||
output_embedding = self(source_embedding)
|
||||
loss = self.loss_fn(output_embedding, target_embedding)
|
||||
self.log('val_loss', loss, prog_bar = True, logger = True)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr = self.lr)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
|
||||
return\
|
||||
{
|
||||
'optimizer': optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': scheduler,
|
||||
'monitor': 'train_loss',
|
||||
'interval': 'epoch',
|
||||
'frequency': 1
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_loaders() -> Tuple[Loader, Loader]:
|
||||
loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
loader_num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
|
||||
training_dataset, validate_dataset = split_dataset()
|
||||
training_loader = DataLoader(training_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = True, pin_memory = True)
|
||||
validation_loader = DataLoader(validate_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = False, pin_memory = True)
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
|
||||
input_source_path = CONFIG.get('preparing.input', 'source_path')
|
||||
input_target_path = CONFIG.get('preparing.input', 'target_path')
|
||||
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
|
||||
|
||||
source_input = torch.from_numpy(numpy.load(input_source_path)).float()
|
||||
target_input = torch.from_numpy(numpy.load(input_target_path)).float()
|
||||
dataset = TensorDataset(source_input, target_input)
|
||||
|
||||
dataset_size = len(dataset)
|
||||
training_size = int(loader_split_ratio * len(dataset))
|
||||
validation_size = int(dataset_size - training_size)
|
||||
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
|
||||
return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
|
||||
|
||||
return Trainer(
|
||||
max_epochs = trainer_max_epochs,
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'train_loss',
|
||||
dirpath = output_directory_path,
|
||||
filename = output_file_pattern,
|
||||
every_n_epochs = 10,
|
||||
save_top_k = 3,
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
enable_progress_bar = True,
|
||||
log_every_n_steps = 2
|
||||
)
|
||||
|
||||
|
||||
def train() -> None:
|
||||
trainer = create_trainer()
|
||||
training_loader, validation_loader = create_loaders()
|
||||
model = ArcFaceConverterTrainer()
|
||||
tuner = Tuner(trainer)
|
||||
tuner.lr_find(model, training_loader, validation_loader)
|
||||
trainer.fit(model, training_loader, validation_loader)
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Any, Tuple
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
Batch = Tuple[Tensor, Tensor]
|
||||
Loader = DataLoader[Tuple[Tensor, ...]]
|
||||
|
||||
Embedding = NDArray[Any]
|
||||
EmbeddingPairs = NDArray[Any]
|
||||
FaceLandmark5 = NDArray[Any]
|
||||
VisionFrame = NDArray[Any]
|
||||
@@ -0,0 +1,3 @@
|
||||
OpenRAIL-MS license
|
||||
|
||||
Copyright (c) 2025 Henry Ruhs
|
||||
@@ -0,0 +1,96 @@
|
||||
Embedding Converter
|
||||
===================
|
||||
|
||||
> Convert face embeddings between various models.
|
||||
|
||||

|
||||
|
||||
|
||||
Preview
|
||||
-------
|
||||
|
||||

|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
Setup
|
||||
-----
|
||||
|
||||
This `config.ini` utilizes the MegaFace dataset to train the Embedding Converter for SimSwap.
|
||||
|
||||
```
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/megaface/**/*.jpg
|
||||
```
|
||||
|
||||
```
|
||||
[training.loader]
|
||||
batch_size = 256
|
||||
num_workers = 8
|
||||
split_ratio = 0.95
|
||||
```
|
||||
|
||||
```
|
||||
[training.model]
|
||||
source_path = .models/arcface_w600k_r50.pt
|
||||
target_path = .models/arcface_simswap.pt
|
||||
```
|
||||
|
||||
```
|
||||
[training.trainer]
|
||||
learning_rate = 0.001
|
||||
max_epochs = 4096
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
logger_path = .logs
|
||||
logger_name = arcface_converter_simswap
|
||||
```
|
||||
|
||||
```
|
||||
[training.output]
|
||||
directory_path = .outputs
|
||||
file_pattern = arcface_converter_simswap_{epoch}_{step}
|
||||
resume_path = .outputs/last.ckpt
|
||||
```
|
||||
|
||||
```
|
||||
[exporting]
|
||||
directory_path = .exports
|
||||
source_path = .outputs/last.ckpt
|
||||
target_path = .exports/arcface_converter_simswap.onnx
|
||||
ir_version = 10
|
||||
opset_version = 15
|
||||
```
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
||||
Train the Embedding Converter model.
|
||||
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
|
||||
Launch the TensorBoard to monitor the training.
|
||||
|
||||
```
|
||||
tensorboard --logdir=.logs
|
||||
```
|
||||
|
||||
|
||||
Exporting
|
||||
---------
|
||||
|
||||
Export the model to ONNX.
|
||||
|
||||
```
|
||||
python export.py
|
||||
```
|
||||
@@ -1,34 +1,31 @@
|
||||
[preparing.dataset]
|
||||
dataset_path =
|
||||
crop_size =
|
||||
process_limit =
|
||||
|
||||
[preparing.model]
|
||||
source_path =
|
||||
target_path =
|
||||
|
||||
[preparing.input]
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
|
||||
[training.loader]
|
||||
split_ratio =
|
||||
batch_size =
|
||||
num_workers =
|
||||
split_ratio =
|
||||
|
||||
[training.model]
|
||||
source_path =
|
||||
target_path =
|
||||
|
||||
[training.trainer]
|
||||
learning_rate =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
logger_path =
|
||||
logger_name =
|
||||
|
||||
[training.output]
|
||||
directory_path =
|
||||
file_pattern =
|
||||
resume_path =
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
ir_version =
|
||||
opset_version =
|
||||
|
||||
[execution]
|
||||
providers =
|
||||
@@ -0,0 +1,34 @@
|
||||
import glob
|
||||
from configparser import ConfigParser
|
||||
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .types import Batch
|
||||
|
||||
|
||||
class StaticDataset(Dataset[Tensor]):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
|
||||
self.file_paths = glob.glob(self.config_file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
temp_tensor = io.read_image(file_path)
|
||||
return self.transforms(temp_tensor)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
|
||||
@staticmethod
|
||||
def compose_transforms() -> transforms:
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
import torch
|
||||
|
||||
from .training import EmbeddingConverterTrainer
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def export() -> None:
|
||||
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
|
||||
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
|
||||
config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
|
||||
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
|
||||
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
|
||||
os.makedirs(config_directory_path, exist_ok = True)
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(config_source_path, map_location = 'cpu').eval()
|
||||
model.ir_version = torch.tensor(config_ir_version)
|
||||
input_tensor = torch.randn(1, 512)
|
||||
torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
|
||||
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class EmbeddingConverter(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layers = self.create_layers()
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
|
||||
@staticmethod
|
||||
def create_layers() -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
nn.Linear(512, 1024),
|
||||
nn.Linear(1024, 2048),
|
||||
nn.Linear(2048, 1024),
|
||||
nn.Linear(1024, 512)
|
||||
])
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = input_tensor / torch.norm(input_tensor)
|
||||
|
||||
for layer in self.layers[:-1]:
|
||||
output_tensor = self.leaky_relu(layer(output_tensor))
|
||||
|
||||
output_tensor = self.layers[-1](output_tensor)
|
||||
return output_tensor
|
||||
@@ -0,0 +1,134 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import StaticDataset
|
||||
from .models.embedding_converter import EmbeddingConverter
|
||||
from .types import Batch, Embedding, OptimizerSet
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
class EmbeddingConverterTrainer(LightningModule):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_source_path = config_parser.get('training.model', 'source_path')
|
||||
self.config_target_path = config_parser.get('training.model', 'target_path')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval()
|
||||
self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_embedding : Embedding) -> Embedding:
|
||||
return self.embedding_converter(source_embedding)
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
with torch.no_grad():
|
||||
source_embedding = self.source_embedder(batch)
|
||||
target_embedding = self.target_embedder(batch)
|
||||
output_embedding = self(source_embedding)
|
||||
training_loss = self.mse_loss(output_embedding, target_embedding)
|
||||
self.log('training_loss', training_loss, prog_bar = True)
|
||||
return training_loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
with torch.no_grad():
|
||||
source_embedding = self.source_embedder(batch)
|
||||
output_embedding = self(source_embedding)
|
||||
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
||||
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
|
||||
return validation_score
|
||||
|
||||
def configure_optimizers(self) -> OptimizerSet:
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
optimizer_set =\
|
||||
{
|
||||
'optimizer': optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': scheduler,
|
||||
'monitor': 'training_loss',
|
||||
'interval': 'epoch',
|
||||
'frequency': 1
|
||||
}
|
||||
}
|
||||
|
||||
return optimizer_set
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
|
||||
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
|
||||
config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
|
||||
|
||||
dataset_size = len(dataset) # type:ignore[arg-type]
|
||||
training_size = int(dataset_size * config_split_ratio)
|
||||
validation_size = int(dataset_size - training_size)
|
||||
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
|
||||
return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger(config_logger_path, config_logger_name)
|
||||
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = config_max_epochs,
|
||||
strategy = config_strategy,
|
||||
precision = config_precision,
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'training_loss',
|
||||
dirpath = config_directory_path,
|
||||
filename = config_file_pattern,
|
||||
every_n_epochs = 1,
|
||||
save_top_k = 3,
|
||||
save_last = True
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = StaticDataset(CONFIG_PARSER)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer(CONFIG_PARSER)
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.exists(config_resume_path):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
|
||||
else:
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
|
||||
@@ -0,0 +1,8 @@
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
Batch : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
@@ -0,0 +1,3 @@
|
||||
ResearchRAIL-MS license
|
||||
|
||||
Copyright (c) 2025 Henry Ruhs
|
||||
@@ -0,0 +1,160 @@
|
||||
Face Swapper
|
||||
============
|
||||
|
||||
> Face shape and occlusion aware identity transfer.
|
||||
|
||||

|
||||
|
||||
|
||||
Preview
|
||||
-------
|
||||
|
||||

|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
Setup
|
||||
-----
|
||||
|
||||
This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
|
||||
|
||||
```
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/vggface2/**/*.jpg
|
||||
warp_template = vgg_face_hq_to_arcface_128_v2
|
||||
transform_size = 256
|
||||
batch_mode = equal
|
||||
batch_ratio = 0.2
|
||||
```
|
||||
|
||||
```
|
||||
[training.loader]
|
||||
batch_size = 8
|
||||
num_workers = 8
|
||||
split_ratio = 0.9995
|
||||
```
|
||||
|
||||
```
|
||||
[training.model]
|
||||
generator_embedder_path = .models/blendface.pt
|
||||
loss_embedder_path = .models/adaface.pt
|
||||
gazer_path = .models/gazer.pt
|
||||
face_masker_path = .models/face_masker.pt
|
||||
```
|
||||
|
||||
```
|
||||
[training.model.generator]
|
||||
source_channels = 512
|
||||
output_channels = 4096
|
||||
output_size = 256
|
||||
num_blocks = 2
|
||||
```
|
||||
|
||||
```
|
||||
[training.model.discriminator]
|
||||
input_channels = 3
|
||||
num_filters = 64
|
||||
num_layers = 5
|
||||
num_discriminators = 3
|
||||
kernel_size = 4
|
||||
```
|
||||
|
||||
```
|
||||
[training.model.masker]
|
||||
input_channels = 67
|
||||
output_channels = 1
|
||||
num_filters = 16
|
||||
```
|
||||
|
||||
```
|
||||
[training.losses]
|
||||
adversarial_weight = 1.0
|
||||
cycle_weight = 1.0
|
||||
feature_weight = 10.0
|
||||
reconstruction_weight = 10.0
|
||||
identity_weight = 20.0
|
||||
gaze_weight = 0.05
|
||||
mask_weight = 5.0
|
||||
```
|
||||
|
||||
```
|
||||
[training.trainer]
|
||||
accumulate_size = 4
|
||||
learning_rate = 0.0004
|
||||
max_epochs = 50
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
logger_path = .logs
|
||||
logger_name = face_swapper
|
||||
preview_frequency = 100
|
||||
```
|
||||
|
||||
```
|
||||
[training.output]
|
||||
directory_path = .outputs
|
||||
file_pattern = face_swapper_{epoch}_{step}
|
||||
resume_path = .outputs/last.ckpt
|
||||
```
|
||||
|
||||
```
|
||||
[exporting]
|
||||
directory_path = .exports
|
||||
source_path = .outputs/last.ckpt
|
||||
target_path = .exports/face_swapper.onnx
|
||||
target_size = 256
|
||||
ir_version = 10
|
||||
opset_version = 15
|
||||
precision = full
|
||||
```
|
||||
|
||||
```
|
||||
[inferencing]
|
||||
generator_path = .outputs/last.ckpt
|
||||
embedder_path = .models/arcface.pt
|
||||
source_path = .assets/source.jpg
|
||||
target_path = .assets/target.jpg
|
||||
output_path = .outputs/output.jpg
|
||||
```
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
||||
Train the Face Swapper model.
|
||||
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
|
||||
Launch the TensorBoard to monitor the training.
|
||||
|
||||
```
|
||||
tensorboard --logdir=.logs
|
||||
```
|
||||
|
||||
|
||||
Exporting
|
||||
---------
|
||||
|
||||
Export the model to ONNX.
|
||||
|
||||
```
|
||||
python export.py
|
||||
```
|
||||
|
||||
|
||||
Inferencing
|
||||
-----------
|
||||
|
||||
Inference the model.
|
||||
|
||||
```
|
||||
python infer.py
|
||||
```
|
||||
@@ -0,0 +1,75 @@
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
warp_template =
|
||||
transform_size =
|
||||
batch_mode =
|
||||
batch_ratio =
|
||||
|
||||
[training.loader]
|
||||
batch_size =
|
||||
num_workers =
|
||||
split_ratio =
|
||||
|
||||
[training.model]
|
||||
generator_embedder_path =
|
||||
loss_embedder_path =
|
||||
gazer_path =
|
||||
face_masker_path =
|
||||
|
||||
[training.model.generator]
|
||||
source_channels =
|
||||
output_channels =
|
||||
output_size =
|
||||
num_blocks =
|
||||
|
||||
[training.model.discriminator]
|
||||
input_channels =
|
||||
num_filters =
|
||||
num_layers =
|
||||
num_discriminators =
|
||||
kernel_size =
|
||||
|
||||
[training.model.masker]
|
||||
input_channels =
|
||||
output_channels =
|
||||
num_filters =
|
||||
|
||||
[training.losses]
|
||||
adversarial_weight =
|
||||
cycle_weight =
|
||||
feature_weight =
|
||||
reconstruction_weight =
|
||||
identity_weight =
|
||||
gaze_weight =
|
||||
mask_weight =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
learning_rate =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
logger_path =
|
||||
logger_name =
|
||||
preview_frequency =
|
||||
|
||||
[training.output]
|
||||
directory_path =
|
||||
file_pattern =
|
||||
resume_path =
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
target_size =
|
||||
ir_version =
|
||||
opset_version =
|
||||
precision =
|
||||
|
||||
[inferencing]
|
||||
generator_path =
|
||||
embedder_path =
|
||||
source_path =
|
||||
target_path =
|
||||
output_path =
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from src.preparing import prepare
|
||||
from src.exporting import export
|
||||
|
||||
if __name__ == '__main__':
|
||||
prepare()
|
||||
export()
|
||||
@@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from src.inferencing import infer
|
||||
|
||||
if __name__ == '__main__':
|
||||
infer()
|
||||
@@ -0,0 +1,107 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from configparser import ConfigParser
|
||||
from typing import cast
|
||||
|
||||
import albumentations
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import warp_tensor
|
||||
from .types import Batch, BatchMode, WarpTemplate
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
|
||||
self.config_transform_size = config_parser.getint('training.dataset', 'transform_size')
|
||||
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset', 'batch_mode'))
|
||||
self.config_batch_ratio = config_parser.getfloat('training.dataset', 'batch_ratio')
|
||||
self.config_parser = config_parser
|
||||
self.file_paths = glob.glob(self.config_file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
|
||||
if random.random() < self.config_batch_ratio:
|
||||
if self.config_batch_mode == 'equal':
|
||||
return self.prepare_equal_batch(file_path)
|
||||
if self.config_batch_mode == 'same':
|
||||
return self.prepare_same_batch(file_path)
|
||||
|
||||
return self.prepare_different_batch(file_path)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
|
||||
def compose_transforms(self) -> transforms:
|
||||
return transforms.Compose(
|
||||
[
|
||||
AugmentTransform(),
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
WarpTransform(self.config_parser),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(self.file_paths)
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_equal_batch(self, source_path : str) -> Batch:
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
return source_tensor, source_tensor
|
||||
|
||||
def prepare_same_batch(self, source_path : str) -> Batch:
|
||||
target_directory_path = os.path.dirname(source_path)
|
||||
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
|
||||
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
|
||||
class AugmentTransform:
|
||||
def __init__(self) -> None:
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
|
||||
return self.transforms(image = temp_tensor).get('image')
|
||||
|
||||
@staticmethod
|
||||
def compose_transforms() -> albumentations.Compose:
|
||||
return albumentations.Compose(
|
||||
[
|
||||
albumentations.HorizontalFlip(),
|
||||
albumentations.OneOf(
|
||||
[
|
||||
albumentations.MotionBlur(p = 0.1),
|
||||
albumentations.ZoomBlur(max_factor = (1.0, 1.1), p = 0.1)
|
||||
], p = 0.2),
|
||||
albumentations.RandomBrightnessContrast(p = 0.7),
|
||||
albumentations.ColorJitter(p = 0.2),
|
||||
albumentations.RGBShift(p = 0.7),
|
||||
albumentations.Illumination(p = 0.2),
|
||||
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3)
|
||||
])
|
||||
|
||||
|
||||
class WarpTransform:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config_warp_template = cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.unsqueeze(0)
|
||||
return warp_tensor(temp_tensor, self.config_warp_template).squeeze(0)
|
||||
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .training import FaceSwapperTrainer
|
||||
from .types import Embedding, Mask, Module
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
class HalfPrecision(nn.Module):
|
||||
def __init__(self, model : Module) -> None:
|
||||
super().__init__()
|
||||
self.model = model.half()
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
source_embedding = source_embedding.half()
|
||||
target_tensor = target_tensor.half()
|
||||
output_tensor, output_mask = self.model(source_embedding, target_tensor)
|
||||
output_tensor = output_tensor.float()
|
||||
output_mask = output_mask.float()
|
||||
return output_tensor, output_mask
|
||||
|
||||
|
||||
def export() -> None:
|
||||
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
|
||||
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
|
||||
config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
|
||||
config_target_size = CONFIG_PARSER.getint('exporting', 'target_size')
|
||||
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
|
||||
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
config_precision = CONFIG_PARSER.get('exporting', 'precision')
|
||||
|
||||
os.makedirs(config_directory_path, exist_ok = True)
|
||||
model = FaceSwapperTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
|
||||
|
||||
if config_precision == 'half':
|
||||
model = HalfPrecision(model).eval()
|
||||
|
||||
model.ir_version = torch.tensor(config_ir_version)
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, config_target_size, config_target_size)
|
||||
torch.onnx.export(model, (source_tensor, target_tensor), config_target_path, input_names = [ 'source', 'target' ], output_names = [ 'output', 'mask' ], opset_version = config_opset_version)
|
||||
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet
|
||||
|
||||
WARP_TEMPLATE_SET : WarpTemplateSet =\
|
||||
{
|
||||
'arcface_128_v2_to_arcface_112_v2': torch.tensor(
|
||||
[
|
||||
[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
|
||||
[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
|
||||
]),
|
||||
'ffhq_to_arcface_128_v2': torch.tensor(
|
||||
[
|
||||
[ 8.50048894e-01, -1.29486822e-04, 1.90956388e-03 ],
|
||||
[ 1.29486822e-04, 8.50048894e-01, 9.56254653e-02 ]
|
||||
]),
|
||||
'vgg_face_hq_to_arcface_128_v2': torch.tensor(
|
||||
[
|
||||
[ 1.01305414, -0.00140513, -0.00585911 ],
|
||||
[ 0.00140513, 1.01305414, 0.11169602 ]
|
||||
])
|
||||
}
|
||||
|
||||
|
||||
def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:
|
||||
normed_warp_template = WARP_TEMPLATE_SET.get(warp_template).repeat(input_tensor.shape[0], 1, 1)
|
||||
affine_grid = nn.functional.affine_grid(normed_warp_template.to(input_tensor.device), list(input_tensor.shape))
|
||||
output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, align_corners = False, padding_mode = 'reflection')
|
||||
return output_tensor
|
||||
|
||||
|
||||
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
|
||||
crop_tensor[:, :, :padding[0], :] = 0
|
||||
crop_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
crop_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
|
||||
embedding = embedder(crop_tensor)
|
||||
embedding = nn.functional.normalize(embedding, p = 2)
|
||||
return embedding
|
||||
|
||||
|
||||
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
|
||||
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
|
||||
overlay_tensor[:, 2, :, :] = 1
|
||||
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
|
||||
output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask
|
||||
return output_tensor
|
||||
@@ -0,0 +1,27 @@
|
||||
import configparser
|
||||
|
||||
import torch
|
||||
from torchvision import io
|
||||
|
||||
from .helper import calc_embedding
|
||||
from .training import FaceSwapperTrainer
|
||||
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def infer() -> None:
|
||||
config_generator_path = CONFIG_PARSER.get('inferencing', 'generator_path')
|
||||
config_embedder_path = CONFIG_PARSER.get('inferencing', 'embedder_path')
|
||||
config_source_path = CONFIG_PARSER.get('inferencing', 'source_path')
|
||||
config_target_path = CONFIG_PARSER.get('inferencing', 'target_path')
|
||||
config_output_path = CONFIG_PARSER.get('inferencing', 'output_path')
|
||||
|
||||
generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
|
||||
embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval()
|
||||
|
||||
source_tensor = io.read_image(config_source_path)
|
||||
target_tensor = io.read_image(config_target_path)
|
||||
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor, _ = generator(source_embedding, target_tensor)
|
||||
io.write_jpeg(output_tensor, config_output_path)
|
||||
@@ -0,0 +1,35 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.nld import NLD
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
|
||||
self.config_parser = config_parser
|
||||
self.discriminators = self.create_discriminators()
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
|
||||
def create_discriminators(self) -> nn.ModuleList:
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
for _ in range(self.config_num_discriminators):
|
||||
discriminator = NLD(self.config_parser).sequences
|
||||
discriminators.append(discriminator)
|
||||
|
||||
return discriminators
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> List[Tensor]:
|
||||
temp_tensor = input_tensor
|
||||
output_tensors = []
|
||||
|
||||
for discriminator in self.discriminators:
|
||||
output_tensor = discriminator(temp_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
temp_tensor = self.avg_pool(temp_tensor)
|
||||
|
||||
return output_tensors
|
||||
@@ -0,0 +1,42 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.aad import AAD
|
||||
from ..networks.masknet import MaskNet
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Embedding, Feature, Mask
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.encoder = UNet(config_parser)
|
||||
self.generator = AAD(config_parser)
|
||||
self.masker = MaskNet(config_parser)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
self.masker.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]:
|
||||
output_tensor = self.generator(source_embedding, target_features)
|
||||
target_feature = target_features[-1]
|
||||
output_mask = self.masker(target_tensor, target_feature)
|
||||
output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask)
|
||||
return output_tensor, output_mask
|
||||
|
||||
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
|
||||
return self.encoder(input_tensor)
|
||||
|
||||
|
||||
def init_weight(module : nn.Module) -> None:
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(std = 0.001)
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.xavier_normal_(module.weight.data)
|
||||
|
||||
if isinstance(module, nn.ConvTranspose2d):
|
||||
nn.init.xavier_normal_(module.weight.data)
|
||||
@@ -0,0 +1,186 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss:
|
||||
positive_tensors = []
|
||||
negative_tensors = []
|
||||
|
||||
for discriminator_source_tensor in discriminator_source_tensors:
|
||||
positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ])
|
||||
positive_tensors.append(positive_tensor)
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ])
|
||||
negative_tensors.append(negative_tensor)
|
||||
|
||||
positive_loss = torch.stack(positive_tensors).mean()
|
||||
negative_loss = torch.stack(negative_tensors).mean()
|
||||
discriminator_loss = (positive_loss + negative_loss) * 0.5
|
||||
return discriminator_loss
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight')
|
||||
|
||||
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
temp_tensor = torch.relu(1 - discriminator_output_tensor).mean(dim = [ 1, 2, 3 ]).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
adversarial_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_adversarial_loss = adversarial_loss * self.config_adversarial_weight
|
||||
return adversarial_loss, weighted_adversarial_loss
|
||||
|
||||
|
||||
class CycleLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
|
||||
self.l1_loss = nn.L1Loss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
|
||||
for target_feature, output_feature in zip(target_features, cycle_features):
|
||||
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
feature_loss = torch.stack(temp_tensors).mean()
|
||||
reconstruction_loss = self.l1_loss(target_tensor, cycle_tensor)
|
||||
cycle_loss = (feature_loss + reconstruction_loss) * 0.5
|
||||
weighted_feature_loss = cycle_loss * self.config_cycle_weight
|
||||
return cycle_loss, weighted_feature_loss
|
||||
|
||||
|
||||
class FeatureLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight')
|
||||
|
||||
def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
|
||||
for target_feature, output_feature in zip(target_features, output_features):
|
||||
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
feature_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
weighted_feature_loss = feature_loss * self.config_feature_weight
|
||||
return feature_loss, weighted_feature_loss
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.config_reconstruction_weight = config_parser.getfloat('training.losses', 'reconstruction_weight')
|
||||
self.embedder = embedder
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
|
||||
with torch.no_grad():
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
|
||||
|
||||
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
|
||||
|
||||
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
|
||||
reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5
|
||||
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
|
||||
class IdentityLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
|
||||
self.embedder = embedder
|
||||
|
||||
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_identity_loss = identity_loss * self.config_identity_weight
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
|
||||
super().__init__()
|
||||
self.config_gaze_weight = config_parser.getfloat('training.losses', 'gaze_weight')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.gazer = gazer
|
||||
self.l1_loss = nn.L1Loss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
|
||||
output_pitch, output_yaw = self.detect_gaze(output_tensor)
|
||||
target_pitch, target_yaw = self.detect_gaze(target_tensor)
|
||||
|
||||
pitch_loss = self.l1_loss(output_pitch, target_pitch)
|
||||
yaw_loss = self.l1_loss(output_yaw, target_yaw)
|
||||
|
||||
gaze_loss = (pitch_loss + yaw_loss) * 0.5
|
||||
weighted_gaze_loss = gaze_loss * self.config_gaze_weight
|
||||
return gaze_loss, weighted_gaze_loss
|
||||
|
||||
def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int()
|
||||
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
|
||||
crop_tensor = (crop_tensor + 1) * 0.5
|
||||
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
|
||||
|
||||
with torch.no_grad():
|
||||
pitch, yaw = self.gazer(crop_tensor)
|
||||
|
||||
return pitch, yaw
|
||||
|
||||
|
||||
class MaskLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
|
||||
super().__init__()
|
||||
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.face_masker = face_masker
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
|
||||
target_mask = self.calc_mask(target_tensor)
|
||||
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
mask_loss = self.mse_loss(target_mask, output_mask)
|
||||
weighted_mask_loss = mask_loss * self.config_mask_weight
|
||||
return mask_loss, weighted_mask_loss
|
||||
|
||||
def calc_mask(self, target_tensor : Tensor) -> Tensor:
|
||||
target_tensor = torch.nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear')
|
||||
target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensor = self.face_masker(target_tensor)
|
||||
output_tensor = output_tensor.clamp(0, 1)
|
||||
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
|
||||
|
||||
return output_tensor
|
||||
@@ -0,0 +1,191 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Embedding, Feature
|
||||
|
||||
|
||||
class AAD(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
|
||||
self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels)
|
||||
self.layers = self.create_layers()
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
layers = nn.ModuleList()
|
||||
|
||||
if self.config_output_size == 128:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 256:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 512:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(64, 3, 64, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
return layers
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor:
|
||||
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
|
||||
|
||||
for index, layer in enumerate(self.layers[:-1]):
|
||||
target_feature = target_features[index]
|
||||
temp_tensor = layer(temp_tensors, source_embedding, target_feature)
|
||||
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
|
||||
target_feature = target_features[-1]
|
||||
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_feature)
|
||||
output_tensor = torch.tanh(temp_tensors)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class AdaptiveFeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None:
|
||||
super().__init__()
|
||||
self.context_input_channels = input_channels
|
||||
self.context_output_channels = output_channels
|
||||
self.context_target_channels = target_channels
|
||||
self.context_source_channels = source_channels
|
||||
self.context_num_blocks = num_blocks
|
||||
self.primary_layers = self.create_primary_layers()
|
||||
self.shortcut_layers = self.create_shortcut_layers()
|
||||
|
||||
def create_primary_layers(self) -> nn.ModuleList:
|
||||
primary_layers = nn.ModuleList()
|
||||
|
||||
for index in range(self.context_num_blocks):
|
||||
primary_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
|
||||
nn.ReLU(inplace = True)
|
||||
])
|
||||
|
||||
if index < self.context_num_blocks - 1:
|
||||
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_input_channels, kernel_size = 3, padding = 1, bias = False))
|
||||
else:
|
||||
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False))
|
||||
|
||||
return primary_layers
|
||||
|
||||
def create_shortcut_layers(self) -> nn.ModuleList:
|
||||
shortcut_layers = nn.ModuleList()
|
||||
|
||||
if self.context_input_channels > self.context_output_channels:
|
||||
shortcut_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
|
||||
])
|
||||
|
||||
return shortcut_layers
|
||||
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
|
||||
primary_tensor = input_tensor
|
||||
|
||||
for primary_layer in self.primary_layers:
|
||||
if isinstance(primary_layer, FeatureModulation):
|
||||
primary_tensor = primary_layer(primary_tensor, source_embedding, target_feature)
|
||||
else:
|
||||
primary_tensor = primary_layer(primary_tensor)
|
||||
|
||||
if self.context_input_channels > self.context_output_channels:
|
||||
shortcut_tensor = input_tensor
|
||||
|
||||
for shortcut_layer in self.shortcut_layers:
|
||||
if isinstance(shortcut_layer, FeatureModulation):
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_feature)
|
||||
else:
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor)
|
||||
|
||||
input_tensor = shortcut_tensor
|
||||
|
||||
return primary_tensor + input_tensor
|
||||
|
||||
|
||||
class FeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.context_input_channels = input_channels
|
||||
self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
|
||||
self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
|
||||
self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
self.linear1 = nn.Linear(source_channels, input_channels)
|
||||
self.linear2 = nn.Linear(source_channels, input_channels)
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels)
|
||||
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
|
||||
temp_tensor = self.instance_norm(input_tensor)
|
||||
|
||||
source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
source_modulation = source_scale * temp_tensor + source_shift
|
||||
|
||||
target_scale = self.conv1(target_feature)
|
||||
target_shift = self.conv2(target_feature)
|
||||
target_modulation = target_scale * temp_tensor + target_shift
|
||||
|
||||
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
|
||||
output_tensor = (1 - temp_mask) * target_modulation + temp_mask * source_modulation
|
||||
return output_tensor
|
||||
|
||||
|
||||
class PixelShuffleUpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.sequences = self.create_sequences(input_channels, output_channels)
|
||||
|
||||
@staticmethod
|
||||
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1),
|
||||
nn.PixelShuffle(upscale_factor = 2)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.view(input_tensor.shape[0], -1, 1, 1)
|
||||
output_tensor = self.sequences(temp_tensor)
|
||||
return output_tensor
|
||||
@@ -0,0 +1,111 @@
|
||||
from configparser import ConfigParser
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Feature, Mask
|
||||
|
||||
|
||||
class MaskNet(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
|
||||
self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
|
||||
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
|
||||
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
|
||||
self.up_samples = self.create_up_samples(self.config_num_filters)
|
||||
self.bottleneck = BottleNeck(self.config_num_filters * 4)
|
||||
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
@staticmethod
|
||||
def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
DownSample(input_channels, num_filters),
|
||||
DownSample(num_filters, num_filters * 2),
|
||||
DownSample(num_filters * 2, num_filters * 4)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def create_up_samples(num_filters : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
UpSample(num_filters * 4, num_filters * 2),
|
||||
UpSample(num_filters * 2, num_filters),
|
||||
UpSample(num_filters, num_filters)
|
||||
])
|
||||
|
||||
def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask:
|
||||
output_mask = torch.cat([ input_tensor, input_feature ], dim = 1)
|
||||
|
||||
for down_sample in self.down_samples:
|
||||
output_mask = down_sample(output_mask)
|
||||
|
||||
output_mask = self.bottleneck(output_mask)
|
||||
|
||||
for up_sample in self.up_samples:
|
||||
output_mask = up_sample(output_mask)
|
||||
|
||||
output_mask = self.conv(output_mask)
|
||||
output_mask = self.sigmoid(output_mask)
|
||||
return output_mask
|
||||
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
def __init__(self, num_filters : int):
|
||||
super().__init__()
|
||||
self.sequences = self.create_sequences(num_filters)
|
||||
self.relu = nn.ReLU(inplace = True)
|
||||
|
||||
@staticmethod
|
||||
def create_sequences(num_filters : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = self.sequences(input_tensor) + input_tensor
|
||||
output_tensor = self.relu(output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.sequences = self.create_sequences(input_channels, output_channels)
|
||||
|
||||
@staticmethod
|
||||
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2),
|
||||
nn.ReLU(inplace = True)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = self.sequences(input_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class DownSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.sequences = self.create_sequences(input_channels, output_channels)
|
||||
|
||||
@staticmethod
|
||||
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.MaxPool2d(2)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = self.sequences(input_tensor)
|
||||
return output_tensor
|
||||
@@ -0,0 +1,48 @@
|
||||
import math
|
||||
from configparser import ConfigParser
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class NLD(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_input_channels = config_parser.getint('training.model.discriminator', 'input_channels')
|
||||
self.config_num_filters = config_parser.getint('training.model.discriminator', 'num_filters')
|
||||
self.config_kernel_size = config_parser.getint('training.model.discriminator', 'kernel_size')
|
||||
self.config_num_layers = config_parser.getint('training.model.discriminator', 'num_layers')
|
||||
self.layers = self.create_layers()
|
||||
self.sequences = nn.Sequential(*self.layers)
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
padding = math.ceil((self.config_kernel_size - 1) / 2)
|
||||
current_filters = self.config_num_filters
|
||||
layers = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
])
|
||||
|
||||
for _ in range(1, self.config_num_layers):
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding)
|
||||
]
|
||||
return layers
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
return self.sequences(input_tensor)
|
||||
@@ -0,0 +1,157 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Feature
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.down_samples = self.create_down_samples()
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
def create_down_samples(self) -> nn.ModuleList:
|
||||
down_samples = nn.ModuleList(
|
||||
[
|
||||
DownSample(3, 32),
|
||||
DownSample(32, 64),
|
||||
DownSample(64, 128),
|
||||
DownSample(128, 256),
|
||||
DownSample(256, 512)
|
||||
])
|
||||
|
||||
if self.config_output_size == 128:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 512)
|
||||
])
|
||||
|
||||
if self.config_output_size == 256:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
if self.config_output_size == 512:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 2048),
|
||||
DownSample(2048, 2048)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 2048),
|
||||
DownSample(2048, 4096),
|
||||
DownSample(4096, 4096)
|
||||
])
|
||||
|
||||
return down_samples
|
||||
|
||||
def create_up_samples(self) -> nn.ModuleList:
|
||||
up_samples = nn.ModuleList()
|
||||
|
||||
if self.config_output_size == 128:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(512, 512)
|
||||
])
|
||||
|
||||
if self.config_output_size == 256:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(1024, 1024),
|
||||
UpSample(2048, 512)
|
||||
])
|
||||
|
||||
if self.config_output_size == 512:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(2048, 2048),
|
||||
UpSample(4096, 1024),
|
||||
UpSample(2048, 512)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(4096, 4096),
|
||||
UpSample(8192, 2048),
|
||||
UpSample(4096, 1024),
|
||||
UpSample(2048, 512)
|
||||
])
|
||||
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(1024, 256),
|
||||
UpSample(512, 128),
|
||||
UpSample(256, 64),
|
||||
UpSample(128, 32)
|
||||
])
|
||||
|
||||
return up_samples
|
||||
|
||||
def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]:
|
||||
down_features = []
|
||||
up_features = []
|
||||
temp_feature = target_tensor
|
||||
|
||||
for down_sample in self.down_samples:
|
||||
temp_feature = down_sample(temp_feature)
|
||||
down_features.append(temp_feature)
|
||||
|
||||
bottleneck_feature = down_features[-1]
|
||||
temp_feature = bottleneck_feature
|
||||
|
||||
for index, up_sample in enumerate(self.up_samples):
|
||||
skip_tensor = down_features[-(index + 2)]
|
||||
temp_feature = up_sample(temp_feature, skip_tensor)
|
||||
up_features.append(temp_feature)
|
||||
|
||||
final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
return bottleneck_feature, *up_features, final_feature
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.sequences = self.create_sequences(input_channels, output_channels)
|
||||
|
||||
@staticmethod
|
||||
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
|
||||
output_tensor = self.sequences(input_tensor)
|
||||
output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class DownSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.sequences = self.create_sequences(input_channels, output_channels)
|
||||
|
||||
@staticmethod
|
||||
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = self.sequences(input_tensor)
|
||||
return output_tensor
|
||||
@@ -0,0 +1,248 @@
|
||||
import os
|
||||
import warnings
|
||||
from configparser import ConfigParser
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import ConcatDataset, Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
class FaceSwapperTrainer(LightningModule):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path')
|
||||
self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
|
||||
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
|
||||
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
|
||||
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
|
||||
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
|
||||
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
|
||||
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
|
||||
self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval()
|
||||
self.generator = Generator(config_parser)
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss(config_parser)
|
||||
self.cycle_loss = CycleLoss(config_parser)
|
||||
self.feature_loss = FeatureLoss(config_parser)
|
||||
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
|
||||
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
|
||||
self.gaze_loss = GazeLoss(config_parser, self.gazer)
|
||||
self.mask_loss = MaskLoss(config_parser, self.face_masker)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
with torch.no_grad():
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
|
||||
|
||||
return output_tensor, output_mask
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
||||
|
||||
generator_config =\
|
||||
{
|
||||
'optimizer': generator_optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': generator_scheduler,
|
||||
'interval': 'step'
|
||||
}
|
||||
}
|
||||
discriminator_config =\
|
||||
{
|
||||
'optimizer': discriminator_optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': discriminator_scheduler,
|
||||
'interval': 'step'
|
||||
}
|
||||
}
|
||||
return generator_config, discriminator_config
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
do_update = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
|
||||
generator_output_features = self.generator.encode_features(generator_output_tensor)
|
||||
cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
|
||||
cycle_output_features = self.generator.encode_features(cycle_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
|
||||
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
|
||||
generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
self.manual_backward(generator_loss)
|
||||
|
||||
if do_update:
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
|
||||
self.toggle_optimizer(discriminator_optimizer)
|
||||
self.manual_backward(discriminator_loss)
|
||||
|
||||
if do_update:
|
||||
discriminator_optimizer.step()
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
if self.global_step % self.config_preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)
|
||||
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss', adversarial_loss)
|
||||
self.log('cycle_loss', cycle_loss)
|
||||
self.log('feature_loss', feature_loss)
|
||||
self.log('reconstruction_loss', reconstruction_loss)
|
||||
self.log('identity_loss', identity_loss)
|
||||
self.log('gaze_loss', gaze_loss)
|
||||
self.log('mask_loss', mask_loss)
|
||||
return generator_loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor, _ = self.forward(source_embedding, target_tensor)
|
||||
output_embedding = calc_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
|
||||
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
||||
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
|
||||
return validation_score
|
||||
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None:
|
||||
preview_limit = 8
|
||||
preview_cells = []
|
||||
overlay_tensor = overlay_mask(output_tensor, output_mask)
|
||||
|
||||
for source_tensor, target_tensor, output_tensor, overlay_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]):
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
|
||||
preview_cells.append(preview_cell)
|
||||
|
||||
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
|
||||
preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True)
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
|
||||
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
|
||||
config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
|
||||
|
||||
dataset_size = len(dataset) # type:ignore[arg-type]
|
||||
training_size = int(dataset_size * config_split_ratio)
|
||||
validation_size = int(dataset_size - training_size)
|
||||
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
|
||||
return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
|
||||
datasets = []
|
||||
|
||||
for config_section in config_parser.sections():
|
||||
|
||||
if config_section.startswith('training.dataset'):
|
||||
current_config_parser = ConfigParser()
|
||||
current_config_parser.add_section('training.dataset')
|
||||
|
||||
for key, value in config_parser.items(config_section):
|
||||
current_config_parser.set('training.dataset', key, value)
|
||||
|
||||
datasets.append(DynamicDataset(current_config_parser))
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger(config_logger_path, config_logger_name)
|
||||
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = config_max_epochs,
|
||||
strategy = config_strategy,
|
||||
precision = config_precision,
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'generator_loss',
|
||||
dirpath = config_directory_path,
|
||||
filename = config_file_pattern,
|
||||
every_n_train_steps = 1000,
|
||||
save_top_k = 3,
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
val_check_interval = 1000
|
||||
)
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER))
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER)
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.isfile(config_resume_path):
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
|
||||
else:
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader)
|
||||
@@ -0,0 +1,24 @@
|
||||
from typing import Any, Dict, Literal, Tuple, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
BatchMode = Literal['equal', 'same', 'different']
|
||||
|
||||
Feature : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
Mask : TypeAlias = Tensor
|
||||
Loss : TypeAlias = Tensor
|
||||
|
||||
Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
GazerModule : TypeAlias = Module
|
||||
FaceMaskerModule : TypeAlias = Module
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
WarpTemplate = Literal['arcface_128_v2_to_arcface_112_v2', 'ffhq_to_arcface_128_v2', 'vgg_face_hq_to_arcface_128_v2']
|
||||
WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]
|
||||
@@ -0,0 +1,57 @@
|
||||
from configparser import ConfigParser
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from face_swapper.src.networks.aad import AAD
|
||||
from face_swapper.src.networks.masknet import MaskNet
|
||||
from face_swapper.src.networks.unet import UNet
|
||||
|
||||
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
|
||||
def test_aad_with_unet(output_size : int) -> None:
|
||||
config_parser = ConfigParser()
|
||||
config_parser.read_dict(
|
||||
{
|
||||
'training.model.generator':
|
||||
{
|
||||
'source_channels': '512',
|
||||
'output_channels': str(output_size * 16),
|
||||
'output_size': str(output_size),
|
||||
'num_blocks': '2'
|
||||
}
|
||||
})
|
||||
|
||||
encoder = UNet(config_parser).eval()
|
||||
generator = AAD(config_parser).eval()
|
||||
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, output_size, output_size)
|
||||
|
||||
target_features = encoder(target_tensor)
|
||||
output_tensor = generator(source_tensor, target_features)
|
||||
|
||||
assert output_tensor.shape == (1, 3, output_size, output_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
|
||||
def test_mask_net(output_size : int) -> None:
|
||||
config_parser = ConfigParser()
|
||||
config_parser.read_dict(
|
||||
{
|
||||
'training.model.masker':
|
||||
{
|
||||
'input_channels': '67',
|
||||
'output_channels': '1',
|
||||
'num_filters': '16'
|
||||
}
|
||||
})
|
||||
|
||||
masker = MaskNet(config_parser).eval()
|
||||
|
||||
target_tensor = torch.randn(1, 3, output_size, output_size)
|
||||
target_feature = torch.randn(1, 64, output_size, output_size)
|
||||
|
||||
output_mask = masker(target_tensor, target_feature)
|
||||
|
||||
assert output_mask.shape == (1, 1, output_size, output_size)
|
||||
@@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from src.training import train
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
@@ -5,3 +5,4 @@ disallow_untyped_calls = True
|
||||
disallow_untyped_defs = True
|
||||
ignore_missing_imports = True
|
||||
strict_optional = False
|
||||
explicit_package_bases = True
|
||||
|
||||
+9
-5
@@ -1,6 +1,10 @@
|
||||
lightning==2.4.0
|
||||
numpy==1.26.4
|
||||
--extra-index-url https://download.pytorch.org/whl/cu124
|
||||
albumentations==2.0.5
|
||||
lightning==2.5.1
|
||||
onnx==1.17.0
|
||||
onnxruntime==1.20.0
|
||||
opencv-python==4.10.0.84
|
||||
mxnet==1.9.1
|
||||
onnxruntime==1.21.0
|
||||
pytorch-msssim==1.0.0
|
||||
torch==2.6.0
|
||||
torchdata==0.11.0
|
||||
torchvision==0.21.0
|
||||
tensorboard==2.19.0
|
||||
|
||||
Reference in New Issue
Block a user