diff --git a/face_swapper/README.md b/face_swapper/README.md index 2270858..26f454d 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -26,10 +26,8 @@ Setup This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model. ``` -[preparing.dataset] -dataset_path = .datasets/vggface2 -directory_pattern = {}/* -image_pattern = {}/*.*g +[training.dataset] +file_pattern = .datasets/vggface2/**/*.jpg same_person_probability = 0.2 ``` @@ -37,6 +35,7 @@ same_person_probability = 0.2 [training.loader] batch_size = 8 num_workers = 8 +split_ratio = 0.95 ``` ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 6b835fe..fbc18eb 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -1,12 +1,11 @@ -[preparing.dataset] -dataset_path = -directory_pattern = -image_pattern = +[training.dataset] +file_pattern = same_person_probability = [training.loader] batch_size = num_workers = +split_ratio = [training.model] id_embedder_path = diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py index d1145e1..7410c09 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/data_loader.py @@ -1,44 +1,30 @@ import glob -import os.path import random -from typing import Tuple import cv2 -import torch from torch import Tensor from torch.utils.data import Dataset from torchvision import transforms -from .types import Batch, ImagePathList, ImagePathSet +from .types import Batch class DataLoader(Dataset[Tensor]): - def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None: + def __init__(self, file_pattern : str, same_person_probability : float) -> None: self.same_person_probability = same_person_probability - self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path)) - self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern) + self.file_paths = glob.glob(file_pattern) self.transforms = self.compose_transforms() - def __getitem__(self, index : int) -> Batch: - source_image_path = self.image_paths[index] + def __getitem__(self, index : int) -> Batch: # type:ignore[override] + source_image_path = self.file_paths[index] - if random.random() > self.same_person_probability: + if random.random() < self.same_person_probability: return self.prepare_same_person(source_image_path) return self.prepare_different_person(source_image_path) def __len__(self) -> int: - return len(self.image_paths) - - #todo: remove this method - only use glob.glob in init() - def prepare_image_paths(self, dataset_image_pattern : str) -> Tuple[ImagePathList, ImagePathSet]: - image_paths = [] - image_path_set = {} - - for directory_path in self.directory_paths: - image_paths.extend(glob.glob(dataset_image_pattern.format(directory_path))) - image_path_set[directory_path] = image_paths - return image_paths, image_path_set + return len(self.file_paths) @staticmethod def compose_transforms() -> transforms: @@ -54,20 +40,15 @@ class DataLoader(Dataset[Tensor]): ]) def prepare_different_person(self, source_image_path : str) -> Batch: - is_same_person = torch.tensor(0) - target_image_path = random.choice(self.image_paths) + target_image_path = random.choice(self.file_paths) source_vision_frame = cv2.imread(source_image_path) target_vision_frame = cv2.imread(target_image_path) source_tensor = self.transforms(source_vision_frame) target_tensor = self.transforms(target_vision_frame) - return source_tensor, target_tensor, is_same_person + return source_tensor, target_tensor def prepare_same_person(self, source_image_path : str) -> Batch: - is_same_person = torch.tensor(1) - #todo: why not like in prepare_different_person - target_image_path = random.choice(self.image_path_set.get(os.path.dirname(source_image_path))) source_vision_frame = cv2.imread(source_image_path) - target_vision_frame = cv2.imread(target_image_path) source_tensor = self.transforms(source_vision_frame) - target_tensor = self.transforms(target_vision_frame) - return source_tensor, target_tensor, is_same_person + target_tensor = source_tensor.clone() + return source_tensor, target_tensor diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index 3b0892d..6631fe1 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -21,7 +21,7 @@ def export() -> None: model = Generator() model.load_state_dict(state_dict) model.eval() - model.ir_version = ir_version + model.ir_version = torch.tensor(ir_version) source_tensor = torch.randn(1, 512) target_tensor = torch.randn(1, 3, 256, 256) torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index cd9a41a..8770e97 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,6 +1,7 @@ import numpy import torch from torch import Tensor, nn +from pytorch_msssim import ssim from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor @@ -41,6 +42,13 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 crop_vision_tensor[:, :, :, :padding[2]] = 0 crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 - source_embedding = id_embedder(crop_vision_tensor) + with torch.no_grad(): + source_embedding = id_embedder(crop_vision_tensor) source_embedding = nn.functional.normalize(source_embedding, p = 2) return source_embedding + + +def calc_structural_similarity(swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Tensor: + swap_data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor)) + structural_similarity = 1 - ssim(swap_tensor, target_tensor, data_range = swap_data_range).mean() + return structural_similarity diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 02196c7..0d5c49d 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -2,10 +2,9 @@ import configparser from typing import Tuple import torch -from pytorch_msssim import ssim from torch import Tensor, nn -from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss +from ..helper import calc_id_embedding, calc_structural_similarity, hinge_fake_loss, hinge_real_loss from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() @@ -27,7 +26,7 @@ class FaceSwapperLoss: self.motion_extractor.eval() def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: - source_tensor, target_tensor, is_same_person = batch + source_tensor, target_tensor = batch weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') weight_identity = CONFIG.getfloat('training.losses', 'weight_identity') weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute') @@ -39,7 +38,7 @@ class FaceSwapperLoss: 'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs), 'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor), 'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes), - 'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) + 'loss_reconstruction': self.calc_reconstruction_loss(source_tensor, target_tensor, swap_tensor) } if weight_pose > 0: @@ -95,12 +94,22 @@ class FaceSwapperLoss: loss_attribute = torch.stack(loss_attributes).mean() * 0.5 return loss_attribute - def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: - loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1) - loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5 - loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4) - loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() - loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 + def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: + source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0)) + face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5 + loss_reconstructions = [] + + for index, face_similarity in enumerate(face_similarities): + if face_similarity.item() > 0.9: + loss_mse = self.mse_loss(swap_tensor[index], target_tensor[index]) + loss_ssim = calc_structural_similarity(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0)) + loss_reconstruction = (loss_mse + loss_ssim) * 0.5 + loss_reconstructions.append(loss_reconstruction) + else: + loss_reconstructions.append(torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)) + + loss_reconstruction = torch.stack(loss_reconstructions).mean() return loss_reconstruction def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 0a33709..f47093d 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,6 +1,6 @@ import configparser import os -from typing import Tuple +from typing import Any, Tuple import lightning import torch @@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn from torch.optim import Optimizer -from torch.utils.data import DataLoader as TorchDataLoader, Subset +from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split from .data_loader import DataLoader from .helper import calc_id_embedding @@ -42,7 +42,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): return generator_optimizer, discriminator_optimizer def training_step(self, batch : Batch, batch_index : int) -> Tensor: - source_tensor, target_tensor, is_same_person = batch + source_tensor, target_tensor = batch generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) swap_tensor = self.generator(source_embedding, target_tensor) @@ -75,7 +75,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): return generator_losses.get('loss_generator') def validation_step(self, batch : Batch, batch_index : int) -> Tensor: - source_tensor, target_tensor, _ = batch + source_tensor, target_tensor = batch source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) output_tensor = self.generator(source_embedding, target_tensor) output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0)) @@ -91,7 +91,25 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2)) preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True) - self.logger.experiment.add_image('preview', preview_grid, self.global_step) + self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined] + + +def create_loaders(dataset : Dataset[Any]): # type:ignore[type-arg] + batch_size = CONFIG.getint('training.loader', 'batch_size') + num_workers = CONFIG.getint('training.loader', 'num_workers') + + training_dataset, validate_dataset = split_dataset(dataset) + training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True) + return training_loader, validation_loader # type:ignore[return-value] + + +def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]: + loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio') + training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type] + validation_size = len(dataset) - training_size # type:ignore[arg-type] + training_dataset, validate_dataset = random_split(dataset, [training_size, validation_size]) + return training_dataset, validate_dataset def create_trainer() -> Trainer: @@ -106,7 +124,7 @@ def create_trainer() -> Trainer: logger = logger, log_every_n_steps = 10, max_epochs = trainer_max_epochs, - precision = trainer_precision, + precision = trainer_precision, # type:ignore[arg-type] callbacks = [ ModelCheckpoint( @@ -123,17 +141,12 @@ def create_trainer() -> Trainer: def train() -> None: - dataset_path = CONFIG.get('preparing.dataset', 'dataset_path') - dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern') - dataset_directory_pattern = CONFIG.get('preparing.dataset', 'directory_pattern') - same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability') - batch_size = CONFIG.getint('training.loader', 'batch_size') - num_workers = CONFIG.getint('training.loader', 'num_workers') + dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') + same_person_probability = CONFIG.getfloat('training.dataset', 'same_person_probability') output_resume_path = CONFIG.get('training.output', 'resume_path') - dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability) - training_loader = TorchDataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - validation_loader = TorchDataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + dataset = DataLoader(dataset_file_pattern, same_person_probability) + training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer() trainer = create_trainer() diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index bf2fbff..83d6e6e 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -5,10 +5,7 @@ from numpy.typing import NDArray from torch import Tensor from torch.nn import Module -Batch : TypeAlias = Tuple[Tensor, Tensor, Tensor] - -ImagePathList : TypeAlias = List[str] -ImagePathSet : TypeAlias = Dict[str, ImagePathList] +Batch : TypeAlias = Tuple[Tensor, Tensor] SwapAttributes : TypeAlias = Tuple[Tensor, ...] TargetAttributes : TypeAlias = Tuple[Tensor, ...]