diff --git a/face_swapper/README.md b/face_swapper/README.md index 7daa0a3..0962162 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -28,6 +28,7 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model. ``` [training.dataset] file_pattern = .datasets/vggface2/**/*.jpg +warp_matrix = vgg_face_hq_to_arcface_128_v2 batch_ratio = 0.2 ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index c3ce49b..63aa355 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -1,5 +1,6 @@ [training.dataset] file_pattern = +warp_matrix = batch_ratio = [training.loader] diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index b2d752b..7bcde99 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -7,13 +7,14 @@ from torch.utils.data import Dataset from torchvision import io, transforms from .helper import warp_tensor -from .types import Batch +from .types import Batch, WarpMatrix class DynamicDataset(Dataset[Tensor]): - def __init__(self, file_pattern : str, batch_ratio : float) -> None: + def __init__(self, file_pattern : str, warp_matrix : WarpMatrix, batch_ratio : float) -> None: self.file_paths = glob.glob(file_pattern) self.transforms = self.compose_transforms() + self.warp_matrix = warp_matrix self.batch_ratio = batch_ratio def __getitem__(self, index : int) -> Batch: @@ -27,8 +28,7 @@ class DynamicDataset(Dataset[Tensor]): def __len__(self) -> int: return len(self.file_paths) - @staticmethod - def compose_transforms() -> transforms: + def compose_transforms(self) -> transforms: return transforms.Compose( [ transforms.ToPILImage(), @@ -36,10 +36,13 @@ class DynamicDataset(Dataset[Tensor]): transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1), transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)), transforms.ToTensor(), - transforms.Lambda(lambda temp_tensor: warp_tensor(temp_tensor.unsqueeze(0), 
'vgg_face_hq_to_arcface_128_v2').squeeze(0)), + transforms.Lambda(self.warp_tensor), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) + def warp_tensor(self, temp_tensor : Tensor): + return warp_tensor(temp_tensor.unsqueeze(0), self.warp_matrix).squeeze(0) + def prepare_different_batch(self, source_path : str) -> Batch: target_path = random.choice(self.file_paths) source_tensor = io.read_image(source_path) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 765fe6a..eb8f059 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,7 +1,7 @@ import torch from torch import Tensor, nn -from .types import EmbedderModule, Embedding, Padding, WarpMatrixSet +from .types import EmbedderModule, Embedding, Padding, WarpMatrix, WarpMatrixSet WARP_MATRIX_SET : WarpMatrixSet =\ { @@ -18,9 +18,9 @@ WARP_MATRIX_SET : WarpMatrixSet =\ } -def warp_tensor(input_tensor : Tensor, alignment_matrix : str) -> Tensor: - matrix = WARP_MATRIX_SET.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1) - grid = nn.functional.affine_grid(matrix.to(input_tensor.device), list(input_tensor.shape)) +def warp_tensor(input_tensor : Tensor, warp_matrix : WarpMatrix) -> Tensor: + warp_matrix = WARP_MATRIX_SET.get(warp_matrix).repeat(input_tensor.shape[0], 1, 1) + grid = nn.functional.affine_grid(warp_matrix.to(input_tensor.device), list(input_tensor.shape)) output_tensor = nn.functional.grid_sample(input_tensor, grid, align_corners = False, padding_mode = 'reflection') return output_tensor diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index abaf6d1..3e88f72 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -62,7 +62,7 @@ class FaceSwapperTrainer(lightning.LightningModule): 'lr_scheduler': { 'scheduler': generator_scheduler, - 'interval': 'step', + 'interval': 'step' } } discriminator_config =\ @@ -71,7 +71,7 @@ class FaceSwapperTrainer(lightning.LightningModule): 'lr_scheduler': { 
'scheduler': discriminator_scheduler, - 'interval': 'step', + 'interval': 'step' } } return generator_config, discriminator_config @@ -194,13 +194,14 @@ def create_trainer() -> Trainer: def train() -> None: dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') + dataset_warp_matrix = CONFIG.get('training.dataset', 'warp_matrix') dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio') output_resume_path = CONFIG.get('training.output', 'resume_path') if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - dataset = DynamicDataset(dataset_file_pattern, dataset_batch_ratio) + dataset = DynamicDataset(dataset_file_pattern, dataset_warp_matrix, dataset_batch_ratio) training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer() trainer = create_trainer() diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 440ddc2..3190341 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -1,5 +1,6 @@ from typing import Any, Dict, Tuple, TypeAlias +from typing import Literal from torch import Tensor from torch.nn import Module @@ -18,4 +19,5 @@ MotionExtractorModule : TypeAlias = Module OptimizerConfig : TypeAlias = Any -WarpMatrixSet : TypeAlias = Dict[str, Tensor] +WarpMatrix = Literal['vgg_face_hq_to_arcface_128_v2', 'arcface_128_v2_to_arcface_112_v2'] +WarpMatrixSet : TypeAlias = Dict[WarpMatrix, Tensor]