This commit is contained in:
harisreedhar
2025-02-21 19:40:44 +05:30
committed by henryruhs
parent b33281425a
commit db44c91dd8
8 changed files with 75 additions and 69 deletions
+3 -4
View File
@@ -26,10 +26,8 @@ Setup
This `config.ini` utilizes the VGGFace2 dataset to train the Face Swapper model.
```
[preparing.dataset]
dataset_path = .datasets/vggface2
directory_pattern = {}/*
image_pattern = {}/*.*g
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
same_person_probability = 0.2
```
@@ -37,6 +35,7 @@ same_person_probability = 0.2
[training.loader]
batch_size = 8
num_workers = 8
split_ratio = 0.95
```
```
+3 -4
View File
@@ -1,12 +1,11 @@
[preparing.dataset]
dataset_path =
directory_pattern =
image_pattern =
[training.dataset]
file_pattern =
same_person_probability =
[training.loader]
batch_size =
num_workers =
split_ratio =
[training.model]
id_embedder_path =
+11 -30
View File
@@ -1,44 +1,30 @@
import glob
import os.path
import random
from typing import Tuple
import cv2
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from .types import Batch, ImagePathList, ImagePathSet
from .types import Batch
class DataLoader(Dataset[Tensor]):
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None:
def __init__(self, file_pattern : str, same_person_probability : float) -> None:
self.same_person_probability = same_person_probability
self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path))
self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern)
self.file_paths = glob.glob(file_pattern)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
source_image_path = self.image_paths[index]
def __getitem__(self, index : int) -> Batch: # type:ignore[override]
source_image_path = self.file_paths[index]
if random.random() > self.same_person_probability:
if random.random() < self.same_person_probability:
return self.prepare_same_person(source_image_path)
return self.prepare_different_person(source_image_path)
def __len__(self) -> int:
return len(self.image_paths)
#todo: remove this method - only use glob.glob in init()
def prepare_image_paths(self, dataset_image_pattern : str) -> Tuple[ImagePathList, ImagePathSet]:
image_paths = []
image_path_set = {}
for directory_path in self.directory_paths:
image_paths.extend(glob.glob(dataset_image_pattern.format(directory_path)))
image_path_set[directory_path] = image_paths
return image_paths, image_path_set
return len(self.file_paths)
@staticmethod
def compose_transforms() -> transforms:
@@ -54,20 +40,15 @@ class DataLoader(Dataset[Tensor]):
])
def prepare_different_person(self, source_image_path : str) -> Batch:
is_same_person = torch.tensor(0)
target_image_path = random.choice(self.image_paths)
target_image_path = random.choice(self.file_paths)
source_vision_frame = cv2.imread(source_image_path)
target_vision_frame = cv2.imread(target_image_path)
source_tensor = self.transforms(source_vision_frame)
target_tensor = self.transforms(target_vision_frame)
return source_tensor, target_tensor, is_same_person
return source_tensor, target_tensor
def prepare_same_person(self, source_image_path : str) -> Batch:
is_same_person = torch.tensor(1)
#todo: why not like in prepare_different_person
target_image_path = random.choice(self.image_path_set.get(os.path.dirname(source_image_path)))
source_vision_frame = cv2.imread(source_image_path)
target_vision_frame = cv2.imread(target_image_path)
source_tensor = self.transforms(source_vision_frame)
target_tensor = self.transforms(target_vision_frame)
return source_tensor, target_tensor, is_same_person
target_tensor = source_tensor.clone()
return source_tensor, target_tensor
+1 -1
View File
@@ -21,7 +21,7 @@ def export() -> None:
model = Generator()
model.load_state_dict(state_dict)
model.eval()
model.ir_version = ir_version
model.ir_version = torch.tensor(ir_version)
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, 256, 256)
torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version)
+9 -1
View File
@@ -1,6 +1,7 @@
import numpy
import torch
from torch import Tensor, nn
from pytorch_msssim import ssim
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
@@ -41,6 +42,13 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
source_embedding = id_embedder(crop_vision_tensor)
with torch.no_grad():
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = nn.functional.normalize(source_embedding, p = 2)
return source_embedding
def calc_structural_similarity(swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Tensor:
    """
    Return one minus the mean SSIM between the swap tensor and the target tensor.

    The data range handed to ssim() is the spread (max - min) of the swap tensor
    alone, converted to a plain Python float.
    """
    value_span = torch.max(swap_tensor) - torch.min(swap_tensor)
    similarity = ssim(swap_tensor, target_tensor, data_range = float(value_span))
    return 1 - similarity.mean()
+19 -10
View File
@@ -2,10 +2,9 @@ import configparser
from typing import Tuple
import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from ..helper import calc_id_embedding, calc_structural_similarity, hinge_fake_loss, hinge_real_loss
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -27,7 +26,7 @@ class FaceSwapperLoss:
self.motion_extractor.eval()
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
source_tensor, target_tensor, is_same_person = batch
source_tensor, target_tensor = batch
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
@@ -39,7 +38,7 @@ class FaceSwapperLoss:
'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
'loss_reconstruction': self.calc_reconstruction_loss(source_tensor, target_tensor, swap_tensor)
}
if weight_pose > 0:
@@ -95,12 +94,22 @@ class FaceSwapperLoss:
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5
loss_reconstructions = []
for index, face_similarity in enumerate(face_similarities):
if face_similarity.item() > 0.9:
loss_mse = self.mse_loss(swap_tensor[index], target_tensor[index])
loss_ssim = calc_structural_similarity(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0))
loss_reconstruction = (loss_mse + loss_ssim) * 0.5
loss_reconstructions.append(loss_reconstruction)
else:
loss_reconstructions.append(torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype))
loss_reconstruction = torch.stack(loss_reconstructions).mean()
return loss_reconstruction
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
+28 -15
View File
@@ -1,6 +1,6 @@
import configparser
import os
from typing import Tuple
from typing import Any, Tuple
import lightning
import torch
@@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataLoader, Subset
from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split
from .data_loader import DataLoader
from .helper import calc_id_embedding
@@ -42,7 +42,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
return generator_optimizer, discriminator_optimizer
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor, is_same_person = batch
source_tensor, target_tensor = batch
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
swap_tensor = self.generator(source_embedding, target_tensor)
@@ -75,7 +75,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
return generator_losses.get('loss_generator')
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor, _ = batch
source_tensor, target_tensor = batch
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
output_tensor = self.generator(source_embedding, target_tensor)
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
@@ -91,7 +91,25 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2))
preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True)
self.logger.experiment.add_image('preview', preview_grid, self.global_step)
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
def create_loaders(dataset : Dataset[Any]): # type:ignore[type-arg]
    """
    Build the training and validation loaders for the given dataset.

    The dataset is divided by split_dataset(). The batch size and the worker
    count are read from the [training.loader] section of CONFIG.

    Returns a (training_loader, validation_loader) pair. The training loader
    shuffles and drops the last incomplete batch, the validation loader does
    neither.
    """
    batch_size = CONFIG.getint('training.loader', 'batch_size')
    num_workers = CONFIG.getint('training.loader', 'num_workers')
    # torch raises a ValueError when persistent_workers is enabled without any
    # worker process, so only keep workers alive when there are some
    persistent_workers = num_workers > 0

    training_dataset, validate_dataset = split_dataset(dataset)
    training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = persistent_workers)
    validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = persistent_workers)
    return training_loader, validation_loader # type:ignore[return-value]
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
    """
    Randomly divide the dataset into a training part and a validation part.

    The training share is the split_ratio from the [training.loader] section
    of CONFIG, truncated to a whole number of samples. The validation part
    receives the remaining samples.
    """
    split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
    dataset_size = len(dataset) # type:ignore[arg-type]
    training_size = int(split_ratio * dataset_size)
    subsets = random_split(dataset, [ training_size, dataset_size - training_size ])
    return subsets[0], subsets[1]
def create_trainer() -> Trainer:
@@ -106,7 +124,7 @@ def create_trainer() -> Trainer:
logger = logger,
log_every_n_steps = 10,
max_epochs = trainer_max_epochs,
precision = trainer_precision,
precision = trainer_precision, # type:ignore[arg-type]
callbacks =
[
ModelCheckpoint(
@@ -123,17 +141,12 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
dataset_directory_pattern = CONFIG.get('preparing.dataset', 'directory_pattern')
same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability')
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
same_person_probability = CONFIG.getfloat('training.dataset', 'same_person_probability')
output_resume_path = CONFIG.get('training.output', 'resume_path')
dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
training_loader = TorchDataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = TorchDataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
dataset = DataLoader(dataset_file_pattern, same_person_probability)
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()
+1 -4
View File
@@ -5,10 +5,7 @@ from numpy.typing import NDArray
from torch import Tensor
from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor, Tensor]
ImagePathList : TypeAlias = List[str]
ImagePathSet : TypeAlias = Dict[str, ImagePathList]
Batch : TypeAlias = Tuple[Tensor, Tensor]
SwapAttributes : TypeAlias = Tuple[Tensor, ...]
TargetAttributes : TypeAlias = Tuple[Tensor, ...]